forked from mindspore-Ecosystem/mindspore

Vjp and fix jvp

parent 12ced9e89b · commit 4f4a344149
@@ -18,7 +18,7 @@ Grad
 Cells of grad function. Calculate the gradient of input network or function.
 """

-from .cell_grad import Jvp
+from .cell_grad import Jvp, Vjp


-__all__ = ['Jvp']
+__all__ = ['Jvp', 'Vjp']

@@ -15,8 +15,8 @@
 """cell grad"""
 from ..cell import Cell
 from ...ops import composite as C
-from ...ops.primitive import Primitive
+from ...ops import operations as P
+from ...ops.primitive import Primitive
 from ...common import dtype as mstype
 from ...common.api import ms_function

@@ -73,9 +73,9 @@ class Jvp(Cell):
         self.tuple_len = Primitive("tuple_len")

     @ms_function
-    def construct(self, *total_input):
-        jvp_input = total_input[0:-1]
-        v = total_input[-1]
+    def construct(self, *args):
+        jvp_input = args[0:-1]
+        v = args[-1]
         output = self.fn(*jvp_input)

         if self.issubclass_(self.typeof(output), mstype.tuple_):

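After the rename above, `Jvp` is still called with the tangent vector as the last positional argument. A minimal usage sketch, reusing the cube network and values from the tests later in this diff (`CubeNet` is a hypothetical stand-in; this code is illustrative, not part of the commit):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.nn.grad import Jvp

    context.set_context(mode=context.GRAPH_MODE)

    class CubeNet(nn.Cell):
        # Stand-in for the tests' SingleInputSingleOutputNet: f(x) = x**3.
        def construct(self, x):
            return x**3

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.ones((2, 2)).astype(np.float32))

    primal, tangent = Jvp(CubeNet())(x, v)  # v rides along as the last argument
    # primal  = x**3         -> [[1, 8], [27, 64]]
    # tangent = 3 * x**2 * v -> [[3, 12], [27, 48]]
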
@@ -92,3 +92,42 @@ class Jvp(Cell):
         second_gradient_net = self.second_grad_op(self.first_grad)
         gradient_output = second_gradient_net(u, jvp_input, v)
         return output, gradient_output
+
+
+class Vjp(Cell):
+    """
+    Computes the dot product between a vector `v` and the Jacobian of the given network at the point
+    given by the inputs.
+
+    Args:
+        network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
+
+    Inputs:
+        - **inputs** (Tensors) - The inputs to `network`. Must be a tuple or a list.
+        - **v** (tuple of Tensors or Tensor) - The vector for which the vector Jacobian product is computed.
+          Must have the same size as the output of `network`.
+
+    Outputs:
+        A tuple with:
+        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
+        vjp (Tuple(Tensor...)) - The result of the dot product.
+    """
+
+    def __init__(self, fn):
+        super(Vjp, self).__init__()
+        self.fn = fn
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.grad_single_value = C.GradOperation(sens_param=True)
+        self.issubclass_ = P.IsSubClass()
+        self.typeof = Primitive('typeof')
+        self.tuple_len = Primitive("tuple_len")
+
+    @ms_function
+    def construct(self, *args):
+        front_input = args[0:-1]
+        output = self.fn(*front_input)
+        if self.tuple_len(front_input) == 1:
+            gradient_output = self.grad_single_value(self.fn)(*args)
+        else:
+            gradient_output = self.grad(self.fn)(*args)
+        return output, gradient_output

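`Vjp` mirrors the `Jvp` calling convention: the cotangent `v` comes last, and `construct` dispatches on `tuple_len` so a single-input network goes through `grad_single_value` while multi-input networks use the `get_all=True` operation. A sketch of both call shapes, using the networks from the new test file at the end of this diff (illustrative, not part of the commit):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.nn.grad import Vjp

    context.set_context(mode=context.GRAPH_MODE)

    class SingleInputNet(nn.Cell):
        def construct(self, x):
            return x**3

    class MultipleInputsOutputNet(nn.Cell):
        def construct(self, x, y):
            return 2*x, y**3

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.ones((2, 2)).astype(np.float32))

    # Single input: dispatched to grad_single_value.
    primal, cotangent = Vjp(SingleInputNet())(x, v)
    # cotangent = v * 3 * x**2 -> [[3, 12], [27, 48]]

    # Multiple inputs: v matches the tuple output, one gradient per input.
    primals, cotangents = Vjp(MultipleInputsOutputNet())(x, x, (v, v))
    # cotangents = (2 * v, 3 * x**2 * v)
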
@@ -20,8 +20,6 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from .primitive import Primitive
-from ..common import Tensor
-from ..nn.grad import Jvp
 from . import operations as P
 from .operations import _grad_ops
 from .composite import GradOperation

@@ -156,63 +154,6 @@ identity = P.identity()
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)

-def _to_tuple(inp, arg_name):
-    """
-    Check whether input to jvp is valid and convert the input to tuple.
-    """
-    if isinstance(inp, list):
-        inp = tuple(inp)
-    inp_is_tuple = True
-    if not isinstance(inp, tuple):
-        inp = (inp,)
-        inp_is_tuple = False
-    for index, value in enumerate(inp):
-        if not isinstance(value, Tensor):
-            if inp_is_tuple:
-                raise TypeError(
-                    "The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
-                                                                                                   type(value)))
-            raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
-    return inp
-
-
-def jvp(fn, jvp_input, v=None):
-    """
-    Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
-
-    Inputs:
-        - **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
-        - **inputs** (Tensor or tuple/list of Tensors) - The inputs to `network`.
-        - **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
-          Must have the same size as the input of `network`.
-
-    Outputs:
-        A tuple with:
-        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
-        jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
-    """
-    inputs_tuple = _to_tuple(jvp_input, "input")
-    if v is not None:
-        v_tuple = _to_tuple(v, "v")
-        if len(v_tuple) != len(inputs_tuple):
-            raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
-                             "of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
-        for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
-            if v_value.shape != inputs_value.shape:
-                raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
-                                 "should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
-    else:
-        if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
-            raise ValueError("The vector v can only be None if the input is a Tensor with single element")
-        v = Tensor([1], dtype=inputs_tuple[0].dtype)
-
-    if len(inputs_tuple) != 1:
-        total_input = (*inputs_tuple, v_tuple)
-    else:
-        total_input = (jvp_input, v)
-    return Jvp(fn)(*total_input)
-
-
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.

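With the free `jvp` function removed from the functional module, call sites construct the `nn.grad` cell directly, as every test update below does; the tuple inputs that `jvp` accepted are now splatted into positional arguments. A sketch of the mechanical rewrite (names taken from the tests in this diff):

    # before this commit:  primal, grad = jvp(net, (x, y), (v1, v2))
    # after this commit:
    from mindspore.nn.grad import Jvp
    primal, grad = Jvp(net)(x, y, (v1, v2))
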
@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp

 context.set_context(mode=context.GRAPH_MODE)

@@ -53,7 +53,7 @@ def test_jvp_single_input_single_output_default_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -67,7 +67,7 @@ def test_jvp_single_input_single_output_custom_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -83,7 +83,7 @@ def test_jvp_single_input_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -105,7 +105,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -126,7 +126,7 @@ def test_jvp_multiple_inputs_single_output_default_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -142,7 +142,7 @@ def test_jvp_multiple_inputs_single_output_custom_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -159,7 +159,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -183,7 +183,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -192,51 +192,3 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     assert len(grad) == 2
     assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
     assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_single_tensor_default_v():
-    x = Tensor(np.array([1]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    expect_primal = Tensor(np.array([1]).astype(np.float32))
-    expect_grad = Tensor(np.array([3]).astype(np.float32))
-    primal, grad = jvp(net, x)
-    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
-    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_default_v():
-    x = Tensor(np.array([1, 2]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x)
-    assert "The vector v can only be None if the input is " in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_shape_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = Tensor(np.array([[1], [4]]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "The tensor shape at the index 0 of v" in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "v is not the same size as the function inputs." in str(ex.value)

@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp

 context.set_context(mode=context.PYNATIVE_MODE)

@@ -52,7 +52,7 @@ def test_jvp_single_input_single_output_default_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -66,7 +66,7 @@ def test_jvp_single_input_single_output_custom_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -82,7 +82,7 @@ def test_jvp_single_input_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -104,7 +104,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -127,7 +127,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -151,7 +151,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())

@@ -172,7 +172,7 @@ def test_jvp_multiple_inputs_single_output_default_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -188,6 +188,6 @@ def test_jvp_multiple_inputs_single_output_custom_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

@@ -0,0 +1,71 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test vjp in graph mode"""
+import numpy as np
+import pytest
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn.grad import Vjp
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class SingleInputNet(nn.Cell):
+    def construct(self, x):
+        return x**3
+
+
+class MultipleInputsOutputNet(nn.Cell):
+    def construct(self, x, y):
+        return 2*x, y**3
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_single_input_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = SingleInputNet()
+    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, v)
+    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
+    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_multiple_inputs_default_v_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = MultipleInputsOutputNet()
+    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
+    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
+    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, y, (v, v))
+    assert isinstance(primal, tuple)
+    assert len(primal) == 2
+    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
+    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
+    assert isinstance(grad, tuple)
+    assert len(grad) == 2
+    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
+    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
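The expected values in the new tests follow directly from the chain rule; a plain NumPy check of the single-input case reproduces them (a verification sketch, not part of the commit):

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v = np.ones((2, 2), dtype=np.float32)

    primal = x**3       # [[1, 8], [27, 64]]
    vjp = v * 3 * x**2  # elementwise x**3 has a diagonal Jacobian with entries 3 * x**2
    assert np.allclose(vjp, np.array([[3, 12], [27, 48]], dtype=np.float32))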