forked from mindspore-Ecosystem/mindspore
!40599 [feature] change vjp
Merge pull request !40599 from chenzhuo/master_vjp
commit 242e5c819e
@@ -1,21 +1,19 @@

mindspore.ops.vjp
=================

.. py:function:: mindspore.ops.vjp(fn, inputs, v)
.. py:function:: mindspore.ops.vjp(fn, inputs, has_aux=False)

    Computes the vector-Jacobian product (VJP) of the given network. VJP corresponds to
    `reverse-mode automatic differentiation <https://www.mindspore.cn/docs/zh-CN/master/design/auto_gradient.html#反向自动微分>`_.

    .. note::
        This interface is subject to change in the future.

    Args:
        - **fn** (Union[Function, Cell]) - The function or network to be differentiated. It takes Tensors as inputs and returns a Tensor or a tuple of Tensors.
        - **inputs** (Union[Tensor, tuple[Tensor], list[Tensor]]) - The inputs passed to the network `fn` .
        - **v** (Union[Tensor, tuple[Tensor], list[Tensor]]) - The vector to be multiplied with the Jacobian matrix; its shape and type must match the forward output of the network.
        - **has_aux** (bool) - If `has_aux` is True, only the first output of `fn` contributes to the differentiation of `fn`, and the other outputs are returned directly. In that case `fn` must return more than one output. Default: False.

    Returns:
        - **net_output** (Union[Tensor, tuple[Tensor]]) - The forward output of the given network.
        - **vjp** (Union[NoneType, int, tuple[int]]) - The result of the vector-Jacobian product.
        - **vjp_fn** (Function) - The function used to compute the vector-Jacobian product. It accepts inputs whose shape and type match `net_output` .
        - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Returned only when `has_aux` is True. `aux_value` consists of the outputs of `fn(inputs)` other than the first one and does not contribute to the differentiation of `fn` .

    Raises:
        - **TypeError** - `inputs` or `v` is not of a required type.
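For orientation, a minimal sketch of how a call site migrates from the old signature to the new one. `SingleInputNet`, `x` and `v` mirror the values used in the tests further down; the commented-out line is the pre-change call form.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import vjp

class SingleInputNet(nn.Cell):
    def construct(self, x):
        return x ** 3

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))

# Before this change: forward value and VJP came back from a single call.
# primal, grad = vjp(SingleInputNet(), x, v)

# After this change: vjp returns the forward value and a closure; the vector
# is supplied to the closure, which returns one gradient per input as a tuple.
primal, grad_fn = vjp(SingleInputNet(), x)
gradient = grad_fn(v)  # gradient[0] == 3 * x ** 2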
@@ -195,27 +195,3 @@ class Vjp(Cell):

        else:
            gradient_output = self.grad(self.fn)(*args)
        return output, gradient_output


class _VjpInner(Cell):
    """
    Computes the dot product between a vector `v` and the Jacobian of the given network at the point
    given by the inputs. This class implements the inner process of function vjp.
    """

    def __init__(self):
        super(_VjpInner, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.grad_single_value = C.GradOperation(sens_param=True)
        self.tuple_len = Primitive("tuple_len")

    def construct(self, *args):
        fn = args[0]
        front_input = args[1:-1]
        input_with_v = args[1:]
        output = fn(*front_input)
        if self.tuple_len(front_input) == 1:
            gradient_output = self.grad_single_value(fn)(*input_with_v)
        else:
            gradient_output = self.grad(fn)(*input_with_v)
        return output, gradient_output
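Both the `_VjpInner` cell removed here and the module-level `vjp_grad` introduced in grad_func.py rest on the same primitive: `GradOperation` with `sens_param=True` appends a sensitivity (cotangent) argument, so the transformed function evaluates a vector-Jacobian product instead of assuming an all-ones sensitivity. A minimal standalone sketch; `CubeNet` is a placeholder network mirroring the `x ** 3` cell used in the tests.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation

class CubeNet(nn.Cell):
    def construct(self, x):
        return x ** 3

# sens_param=True adds a trailing sensitivity argument v, so grad_op(net)(x, v)
# computes the vector-Jacobian product for that v; get_all=True returns one
# gradient per network input, packed into a tuple.
grad_op = GradOperation(get_all=True, sens_param=True)

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.ones((2, 2), dtype=np.float32))
gradient = grad_op(CubeNet())(x, v)  # gradient[0] == 3 * x ** 2, as in the tests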
@@ -21,11 +21,10 @@ from mindspore.common import ms_function

from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.grad.cell_grad import _JvpInner
from mindspore.nn.grad.cell_grad import _VjpInner
from mindspore.nn.grad.cell_grad import _LinearizeInner
from mindspore.ops.primitive import constexpr
from mindspore.ops.function import ones, expand_dims
from mindspore.ops.composite import _Grad, _TaylorOperation
from mindspore.ops.composite import _Grad, _TaylorOperation, GradOperation
from mindspore.ops import operations as P

cast = P.Cast()
@@ -664,24 +663,41 @@ def linearize(fn, inputs):

    return output, partial(_wrap_container, output, *inputs)


def vjp(fn, inputs, v):
def _check_tensor(inputs):
    if not isinstance(inputs, (Tensor, tuple)):
        raise TypeError("The inputs type must be Tensor.")
    if isinstance(inputs, tuple):
        for item in inputs:
            if not isinstance(item, (Tensor, tuple, list)):
                raise TypeError("The inputs type must be Tensor.")
    return True


vjp_grad = GradOperation(get_all=True, sens_param=True)


def vjp(fn, *inputs, has_aux=False):
    """
    Compute the vector-jacobian-product of the given network. `vjp` corresponds to
    `reverse-mode differentiation <https://www.mindspore.cn/docs/en/master/design/auto_gradient.html#reverse-mode-ad>`_.

    Note:
        This function is subject to change in the future.

    Args:
        fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns a single Tensor or a
            tuple of Tensors.
        inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
        v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in the vector-jacobian-product. The shape and
            type of `v` should be the same as `fn(inputs)` .
        has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other
            outputs are returned directly. In this case, `fn` must return more than one output.
            Default: False.

    Returns:
        - **net_output** (Union[Tensor, tuple[Tensor]]) - The result of `fn(inputs)` .
        - **vjp** (Union[Tensor, tuple[Tensor]]) - The result of the vector-jacobian-product.
        Forward outputs and a function to calculate the vjp.

        - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)`. In particular, when `has_aux`
            is set to True, `net_output` is the first output of `fn(inputs)`.
        - **vjp_fn** (Function) - The function that calculates the vector-jacobian-product. Its inputs are the
            vectors whose shape and type should be the same as `net_output` .
        - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Returned only when `has_aux` is True. It contains
            the outputs of `fn(inputs)` other than the first one and does not contribute to the gradient.

    Raises:
        TypeError: `inputs` or `v` is not of a required type.
@@ -690,7 +706,7 @@ def vjp(fn, inputs, v):

    ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> from mindspore.ops import vjp
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y):
@@ -698,32 +714,62 @@ def vjp(fn, inputs, v):

        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
        >>> output = ops.vjp(Net(), (x, y), v)
        >>> print(output[0])
        >>> outputs, vjp_fn = vjp(Net(), x, y)
        >>> print(outputs)
        [[ 2. 10.]
         [30. 68.]]
        >>> print(output[1])
        >>> gradient = vjp_fn(v)
        >>> print(gradient)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 3.00000000e+00, 1.20000000e+01],
         [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]))
        >>> def fn(x, y):
        ...     return 2 * x + y, y ** 3
        >>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
        >>> gradient = vjp_fn(v)
        >>> print(outputs)
        Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 3.00000000e+00, 6.00000000e+00],
         [ 9.00000000e+00, 1.20000000e+01]])
        >>> print(aux)
        Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 8.00000000e+00],
         [ 2.70000000e+01, 6.40000000e+01]])
        >>> print(gradient)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.00000000e+00, 2.00000000e+00],
         [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]))
    """
    vjp_inner = _VjpInner()
    _check_tensor(inputs)

    @ms_function(hash_args=fn)
    def wrap_container(*arg):
        args = arg[:-1]
        vectors = arg[-1]
        return vjp_inner(fn, *args, vectors)
    def aux_fn(*args):
        outputs = fn(*args)
        if not isinstance(outputs, tuple) or len(outputs) < 2:
            raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
        res = outputs[0]
        return res

    if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
        _raise_type_error()
    if isinstance(v, list):
        v = tuple(v)
    if isinstance(inputs, (tuple, list)):
        return wrap_container(*inputs, v)
    return wrap_container(inputs, v)
    if has_aux:
        fn_ = aux_fn
    else:
        fn_ = fn

    def wrap_container(*v):
        _check_tensor(v)
        if len(v) == 1:
            return vjp_grad(fn_)(*inputs, v[0])
        return vjp_grad(fn_)(*inputs, v)

    res = fn(*inputs)
    if has_aux:
        if len(res) == 2:
            return res[0], wrap_container, res[1]
        return res[0], wrap_container, res[1:]
    return res, wrap_container


__all__ = [
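A short sketch of how the implementation above behaves at a call site: `wrap_container` closes over `inputs`, so the function returned by `vjp` can be applied to several cotangent vectors, and `has_aux=True` routes every output of `fn` except the first around the differentiation. `fn` and the values follow the docstring example; `v2` is purely illustrative.

import numpy as np
from mindspore import Tensor
from mindspore.ops import vjp

def fn(x, y):
    return 2 * x + y, y ** 3

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.ones((2, 2), dtype=np.float32))
v2 = Tensor(np.full((2, 2), 2.0, dtype=np.float32))

# Only the first output (2 * x + y) is differentiated; y ** 3 is handed back
# unchanged as aux_value and never enters the gradient computation.
net_output, vjp_fn, aux_value = vjp(fn, x, y, has_aux=True)

# The closure keeps (x, y) around, so it can be re-applied to different vectors.
grads_v1 = vjp_fn(v1)  # (2 * v1, 1 * v1): gradients w.r.t. x and y
grads_v2 = vjp_fn(v2)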
@@ -26,12 +26,12 @@ context.set_context(mode=context.GRAPH_MODE)


class SingleInputNet(nn.Cell):
    def construct(self, x):
        return x**3
        return x ** 3


class MultipleInputsOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
        return 2 * x, y ** 3


@pytest.mark.level0
@@ -48,10 +48,10 @@ def test_vjp_single_input_graph():

    net = SingleInputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = vjp(net, x, v)
    primal, grad_fn = vjp(net, x)
    gradient = grad_fn(v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@@ -71,15 +71,16 @@ def test_vjp_multiple_inputs_default_v_graph():

    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = vjp(net, (x, y), (v, v))
    primal, grad_fn = vjp(net, x, y)
    gradient = grad_fn(v, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
    assert isinstance(gradient, tuple)
    assert len(gradient) == 2
    assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
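As the updated assertions show, the returned closure takes one cotangent vector per network output and returns one gradient per network input. A standalone restatement of that shape rule, reusing the test's network and values:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import vjp

class MultipleInputsOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2 * x, y ** 3

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))

primal, grad_fn = vjp(MultipleInputsOutputNet(), x, y)
# Two outputs -> grad_fn expects two cotangent vectors;
# two inputs  -> grad_fn returns a 2-tuple of gradients.
gradient = grad_fn(v, v)
# gradient[0] == 2 * v           (w.r.t. x, from the 2 * x branch)
# gradient[1] == 3 * y ** 2 * v  (w.r.t. y, from the y ** 3 branch)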
@@ -97,14 +98,15 @@ def test_vjp_ms_function_single_input_single_output_default_v_graph():


    @ms_function
    def vjp_with_ms_function(inputs, vectors):
        output, vjp_grad = vjp(net, inputs, vectors)
        output, grad_fn = vjp(net, inputs)
        vjp_grad = grad_fn(vectors)
        return output, vjp_grad

    primal, grad = vjp_with_ms_function(x, v)
    primal, gradient = vjp_with_ms_function(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@@ -120,13 +122,14 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():

    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))

    def test_function(inputs):
        return inputs**3
        return inputs ** 3

    primal, grad = vjp(test_function, x, v)
    primal, grad_fn = vjp(test_function, x)
    gradient = grad_fn(v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@@ -147,12 +150,46 @@ def test_vjp_construct_single_input_single_output_default_v_graph():

            self.net = network

        def construct(self, inputs, vectors):
            net_out, vjp_out = vjp(self.net, inputs, vectors)
            net_out, grad_fn = vjp(self.net, inputs)
            vjp_out = grad_fn(vectors)
            return net_out, vjp_out

    test_net_graph = Net(SingleInputNet())
    primal, grad = test_net_graph(x, v)
    primal, gradient = test_net_graph(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_outputs_with_has_aux_graph():
    """
    Features: Function vjp
    Description: Test vjp with multiple inputs and multiple outputs, with has_aux set to True, in graph mode.
    Expectation: No exception.
    """

    def fn(x, y):
        return 2 * x + y, y ** 3

    def fn2(*args):
        return fn(*args)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_primal = Tensor(np.array([[3, 6], [9, 12]]).astype(np.float32))
    expect_aux = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    primal, grad_fn, aux = vjp(fn2, x, y, has_aux=True)
    gradient = grad_fn(v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(aux.asnumpy(), expect_aux.asnumpy())
    assert isinstance(gradient, tuple)
    assert len(gradient) == 2
    assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
@@ -25,12 +25,12 @@ context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputNet(nn.Cell):
    def construct(self, x):
        return x**3
        return x ** 3


class MultipleInputsOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
        return 2 * x, y ** 3


@pytest.mark.level0
@@ -47,9 +47,10 @@ def test_vjp_single_input_pynative():

    net = SingleInputNet()
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    primal, grad = vjp(net, x, v)
    primal, grad_fn = vjp(net, x)
    gradient = grad_fn(v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@@ -69,11 +70,12 @@ def test_vjp_multiple_inputs_default_v_pynative():

    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    primal, grad = vjp(net, (x, y), (v, v))
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
    primal, grad_fn = vjp(net, x, y)
    gradient = grad_fn(v, v)
    assert isinstance(gradient, tuple)
    assert len(gradient) == 2
    assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -93,13 +95,14 @@ def test_vjp_input_function_single_input_single_output_default_v_pynative():

    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))

    def test_function(inputs):
        return inputs**3
        return inputs ** 3

    primal, grad = vjp(test_function, x, v)
    primal, grad_fn = vjp(test_function, x)
    gradient = grad_fn(v)
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@@ -120,12 +123,46 @@ def test_vjp_construct_single_input_single_output_default_v_pynative():

            self.net = network

        def construct(self, inputs, vectors):
            net_out, vjp_out = vjp(self.net, inputs, vectors)
            net_out, grad_fn = vjp(self.net, inputs)
            vjp_out = grad_fn(vectors)
            return net_out, vjp_out

    test_net_pynative = Net(SingleInputNet())
    primal, grad = test_net_pynative(x, v)
    primal, gradient = test_net_pynative(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
    assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_outputs_with_has_aux_pynative():
    """
    Features: Function vjp
    Description: Test vjp with multiple inputs and multiple outputs, with has_aux set to True, in pynative mode.
    Expectation: No exception.
    """

    def fn(x, y):
        return 2 * x + y, y ** 3

    def fn2(*args):
        return fn(*args)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_primal = Tensor(np.array([[3, 6], [9, 12]]).astype(np.float32))
    expect_aux = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    primal, grad_fn, aux = vjp(fn2, x, y, has_aux=True)
    gradient = grad_fn(v)
    assert isinstance(gradient, tuple)
    assert len(gradient) == 2
    assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(aux.asnumpy(), expect_aux.asnumpy())
@@ -25,12 +25,12 @@ context.set_context(mode=context.GRAPH_MODE)


class SingleInputNet(nn.Cell):
    def construct(self, x):
        return x**3
        return x ** 3


class MultipleInputsOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
        return 2 * x, y ** 3


def test_vjp_single_input_graph():
@@ -42,7 +42,7 @@ def test_vjp_single_input_graph():

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()
    vjp(net, x, v)
    vjp(net, x)[1](v)


def test_vjp_multiple_inputs_default_v_graph():
@@ -55,7 +55,7 @@ def test_vjp_multiple_inputs_default_v_graph():

    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputsOutputNet()
    vjp(net, (x, y), (v, v))
    vjp(net, x, y)[1](v, v)


def test_vjp_wrong_input_type_graph():
@@ -68,4 +68,4 @@ def test_vjp_wrong_input_type_graph():

    v = 1
    net = SingleInputNet()
    with pytest.raises(TypeError):
        vjp(net, x, v)
        vjp(net, x)[1](v)
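One behavioural consequence visible in the updated wrong-input tests: a non-Tensor vector is now validated inside the returned closure, so the TypeError surfaces when that closure is called, not when `vjp` itself is called. A minimal sketch; `cube` is a stand-in for `SingleInputNet`.

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import vjp

def cube(x):
    return x ** 3

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))

primal, vjp_fn = vjp(cube, x)  # succeeds; no vector has been checked yet
with pytest.raises(TypeError):
    vjp_fn(1)                  # the non-Tensor vector is rejected here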
@@ -25,12 +25,12 @@ context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputNet(nn.Cell):
    def construct(self, x):
        return x**3
        return x ** 3


class MultipleInputsOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
        return 2 * x, y ** 3


def test_vjp_single_input_pynative():
@@ -42,7 +42,7 @@ def test_vjp_single_input_pynative():

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()
    vjp(net, x, v)
    vjp(net, x)[1](v)


def test_vjp_multiple_inputs_default_v_pynative():
@@ -55,7 +55,7 @@ def test_vjp_multiple_inputs_default_v_pynative():

    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputsOutputNet()
    vjp(net, (x, y), (v, v))
    vjp(net, x, y)[1](v, v)


def test_vjp_wrong_input_type_pynative():
@@ -68,4 +68,4 @@ def test_vjp_wrong_input_type_pynative():

    v = 1
    net = SingleInputNet()
    with pytest.raises(TypeError):
        vjp(net, x, v)
        vjp(net, x)[1](v)