!40599 [feature] change vjp

Merge pull request !40599 from chenzhuo/master_vjp
This commit is contained in:
i-robot 2022-09-15 01:41:13 +00:00 committed by Gitee
commit 242e5c819e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 195 additions and 101 deletions

View File

@@ -1,21 +1,19 @@
mindspore.ops.vjp
=================
.. py:function:: mindspore.ops.vjp(fn, inputs, v)
.. py:function:: mindspore.ops.vjp(fn, *inputs, has_aux=False)
Compute the vector-Jacobian product (VJP) of the given network. VJP corresponds to `reverse-mode automatic differentiation <https://www.mindspore.cn/docs/zh-CN/master/design/auto_gradient.html#反向自动微分>`_.
.. note::
This interface may change in the future.
Parameters:
- **fn** (Union[Function, Cell]) - The function or network to be differentiated, which takes Tensors as inputs and returns a Tensor or a tuple of Tensors.
- **inputs** (Union[Tensor, tuple[Tensor], list[Tensor]]) - The inputs to the network `fn` .
- **v** (Union[Tensor, tuple[Tensor], list[Tensor]]) - The vector to be multiplied with the Jacobian matrix; its shape and type are the same as the forward result of the network.
- **has_aux** (bool) - If `has_aux` is True, only the first output of `fn` contributes to the differentiation of `fn` , and the other outputs are returned directly. In this case, `fn` must return more than one output. Default: False.
Returns:
- **net_output** (Union[Tensor, tuple[Tensor]]) - The forward result of the network.
- **vjp** (Union[Tensor, tuple[Tensor]]) - The result of the vector-Jacobian product.
- **vjp_fn** (Function) - The function for computing the vector-Jacobian product; it accepts inputs whose shape and type are the same as `net_output` .
- **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Returned only when `has_aux` is True. `aux_value` is the outputs of `fn(inputs)` other than the first one, and it does not contribute to the differentiation of `fn` .
Raises:
- **TypeError** - `inputs` or `v` is not of a required type.
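
For reference, a minimal sketch of the new interface described above; the function and values are illustrative, mirroring the `has_aux` example in the English docstring changed below:

import numpy as np
from mindspore import Tensor
from mindspore.ops import vjp

def fn(x, y):
    # The first output drives differentiation; the second is returned as aux_value.
    return 2 * x + y, y ** 3

x = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
y = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
v = Tensor(np.ones((2, 2), dtype=np.float32))

net_output, vjp_fn, aux_value = vjp(fn, x, y, has_aux=True)
gradient = vjp_fn(v)  # (2 * v, 1 * v): d(2x + y)/dx = 2, d(2x + y)/dy = 1
print(net_output)     # 2x + y
print(aux_value)      # y ** 3, excluded from differentiation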

View File

@@ -195,27 +195,3 @@ class Vjp(Cell):
else:
gradient_output = self.grad(self.fn)(*args)
return output, gradient_output
class _VjpInner(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
given by the inputs. This class implements the inner process of function vjp.
"""
def __init__(self):
super(_VjpInner, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.grad_single_value = C.GradOperation(sens_param=True)
self.tuple_len = Primitive("tuple_len")
def construct(self, *args):
fn = args[0]
front_input = args[1:-1]
input_with_v = args[1:]
output = fn(*front_input)
if self.tuple_len(front_input) == 1:
gradient_output = self.grad_single_value(fn)(*input_with_v)
else:
gradient_output = self.grad(fn)(*input_with_v)
return output, gradient_output
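
Both the removed `_VjpInner` above and the module-level `vjp_grad` introduced by this change rest on the same primitive: `GradOperation` with `sens_param=True`, which takes the cotangent vector `v` as a trailing sensitivity argument and returns vᵀ·J. A minimal standalone sketch (the `Square` net and values here are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ops

class Square(nn.Cell):
    def construct(self, x):
        return x ** 2

x = Tensor(np.array([1., 2., 3.], dtype=np.float32))
v = Tensor(np.array([1., 1., 1.], dtype=np.float32))
grad_op = ops.GradOperation(sens_param=True)  # v is supplied as the trailing sensitivity argument
print(grad_op(Square())(x, v))                # v * 2x -> [2. 4. 6.]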

View File

@@ -21,11 +21,10 @@ from mindspore.common import ms_function
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.grad.cell_grad import _JvpInner
from mindspore.nn.grad.cell_grad import _VjpInner
from mindspore.nn.grad.cell_grad import _LinearizeInner
from mindspore.ops.primitive import constexpr
from mindspore.ops.function import ones, expand_dims
from mindspore.ops.composite import _Grad, _TaylorOperation
from mindspore.ops.composite import _Grad, _TaylorOperation, GradOperation
from mindspore.ops import operations as P
cast = P.Cast()
@@ -664,24 +663,41 @@ def linearize(fn, inputs):
return output, partial(_wrap_container, output, *inputs)
def vjp(fn, inputs, v):
def _check_tensor(inputs):
if not isinstance(inputs, (Tensor, tuple)):
raise TypeError("The inputs type must be Tensor.")
if isinstance(inputs, tuple):
for item in inputs:
if not isinstance(item, (Tensor, tuple, list)):
raise TypeError("The inputs type must be Tensor.")
return True
vjp_grad = GradOperation(get_all=True, sens_param=True)
def vjp(fn, *inputs, has_aux=False):
"""
Compute the vector-jacobian-product of the given network. `vjp` corresponds to
`reverse-mode differentiation <https://www.mindspore.cn/docs/en/master/design/auto_gradient.html#reverse-mode-ad>`_.
Note:
This function is subject to change in the future.
Args:
fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns a single Tensor or a tuple of
Tensors.
inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in vector-jacobian-product. The shape and type of `v`
should be the same as `fn(inputs)` .
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other outputs
are returned directly. In this case, `fn` must return more than one output.
Default: False.
Returns:
- **net_output** (Union[Tensor, tuple[Tensor]]) - The result of `fn(inputs)` .
- **vjp** (Union[Tensor, tuple[Tensor]]) - The result of vector-jacobian-product.
Forward outputs and a function to calculate the vjp.
- **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)`. In particular, when `has_aux` is set
to True, `net_output` is the first output of `fn(inputs)`.
- **vjp_fn** (Function) - The function to calculate the vector-jacobian-product. Its inputs are the vectors whose
shape and type should be the same as `net_output` .
- **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Returned only when `has_aux` is True. `aux_value` consists
of the outputs of `fn(inputs)` other than the first one, and it does not contribute to the gradient.
Raises:
TypeError: `inputs` or `v` does not belong to required types.
@@ -690,7 +706,7 @@ def vjp(fn, inputs, v):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore.ops import vjp
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
... def construct(self, x, y):
@@ -698,32 +714,62 @@ def vjp(fn, inputs, v):
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = ops.vjp(Net(), (x, y), v)
>>> print(output[0])
>>> outputs, vjp_fn = vjp(Net(), x, y)
>>> print(outputs)
[[ 2. 10.]
[30. 68.]]
>>> print(output[1])
>>> gradient = vjp_fn(v)
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00, 1.20000000e+01],
[ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 1.00000000e+00],
[ 1.00000000e+00, 1.00000000e+00]]))
>>> def fn(x, y):
... return 2 * x + y, y ** 3
>>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
>>> gradient = vjp_fn(v)
>>> print(outputs)
Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00, 6.00000000e+00],
[ 9.00000000e+00, 1.20000000e+01]])
>>> print(aux)
Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 8.00000000e+00],
[ 2.70000000e+01, 6.40000000e+01]])
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 2.00000000e+00],
[ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 1.00000000e+00],
[ 1.00000000e+00, 1.00000000e+00]]))
"""
vjp_inner = _VjpInner()
_check_tensor(inputs)
@ms_function(hash_args=fn)
def wrap_container(*arg):
args = arg[:-1]
vectors = arg[-1]
return vjp_inner(fn, *args, vectors)
def aux_fn(*args):
outputs = fn(*args)
if not isinstance(outputs, tuple) or len(outputs) < 2:
raise ValueError("When has_aux is True, fn must return more than one output.")
res = outputs[0]
return res
if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
_raise_type_error()
if isinstance(v, list):
v = tuple(v)
if isinstance(inputs, (tuple, list)):
return wrap_container(*inputs, v)
return wrap_container(inputs, v)
if has_aux:
fn_ = aux_fn
else:
fn_ = fn
def wrap_container(*v):
_check_tensor(v)
if len(v) == 1:
return vjp_grad(fn_)(*inputs, v[0])
return vjp_grad(fn_)(*inputs, v)
res = fn(*inputs)
if has_aux:
if len(res) == 2:
return res[0], wrap_container, res[1]
return res[0], wrap_container, res[1:]
return res, wrap_container
__all__ = [
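
A note on the closure-based design above: a single `vjp` call fixes `fn` and the differentiation point, and the returned `vjp_fn` can then be applied to any number of cotangent vectors (as the code shows, each call re-invokes `vjp_grad`, so the forward pass is not cached). A sketch with illustrative values:

import numpy as np
from mindspore import Tensor
from mindspore.ops import vjp

def cube(x):
    return x ** 3

x = Tensor(np.array([1., 2.], dtype=np.float32))
out, vjp_fn = vjp(cube, x)
for scale in (1.0, 2.0):
    v = Tensor(np.array([scale, scale], dtype=np.float32))
    print(vjp_fn(v))  # a one-element tuple (get_all=True): (v * 3x**2,)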

View File

@@ -26,12 +26,12 @@ context.set_context(mode=context.GRAPH_MODE)
class SingleInputNet(nn.Cell):
def construct(self, x):
return x**3
return x ** 3
class MultipleInputsOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
return 2 * x, y ** 3
@pytest.mark.level0
@@ -48,10 +48,10 @@ def test_vjp_single_input_graph():
net = SingleInputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = vjp(net, x, v)
primal, grad_fn = vjp(net, x)
gradient = grad_fn(v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@@ -71,15 +71,16 @@ def test_vjp_multiple_inputs_default_v_graph():
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = vjp(net, (x, y), (v, v))
primal, grad_fn = vjp(net, x, y)
gradient = grad_fn(v, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
assert isinstance(gradient, tuple)
assert len(gradient) == 2
assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@@ -97,14 +98,15 @@ def test_vjp_ms_function_single_input_single_output_default_v_graph():
@ms_function
def vjp_with_ms_function(inputs, vectors):
output, vjp_grad = vjp(net, inputs, vectors)
output, grad_fn = vjp(net, inputs)
vjp_grad = grad_fn(vectors)
return output, vjp_grad
primal, grad = vjp_with_ms_function(x, v)
primal, gradient = vjp_with_ms_function(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@@ -120,13 +122,14 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
def test_function(inputs):
return inputs**3
return inputs ** 3
primal, grad = vjp(test_function, x, v)
primal, grad_fn = vjp(test_function, x)
gradient = grad_fn(v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@@ -147,12 +150,46 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
self.net = network
def construct(self, inputs, vectors):
net_out, vjp_out = vjp(self.net, inputs, vectors)
net_out, grad_fn = vjp(self.net, inputs)
vjp_out = grad_fn(vectors)
return net_out, vjp_out
test_net_graph = Net(SingleInputNet())
primal, grad = test_net_graph(x, v)
primal, gradient = test_net_graph(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_outputs_with_has_aux_graph():
"""
Features: Function vjp
Description: Test vjp with multiple inputs and multiple outputs with has_aux set to True in graph mode.
Expectation: No exception.
"""
def fn(x, y):
return 2 * x + y, y ** 3
def fn2(*args):
return fn(*args)
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_primal = Tensor(np.array([[3, 6], [9, 12]]).astype(np.float32))
expect_aux = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
primal, grad_fn, aux = vjp(fn2, x, y, has_aux=True)
gradient = grad_fn(v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(aux.asnumpy(), expect_aux.asnumpy())
assert isinstance(gradient, tuple)
assert len(gradient) == 2
assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())

View File

@@ -25,12 +25,12 @@ context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputNet(nn.Cell):
def construct(self, x):
return x**3
return x ** 3
class MultipleInputsOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
return 2 * x, y ** 3
@pytest.mark.level0
@@ -47,9 +47,10 @@ def test_vjp_single_input_pynative():
net = SingleInputNet()
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, x, v)
primal, grad_fn = vjp(net, x)
gradient = grad_fn(v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@@ -69,11 +70,12 @@ def test_vjp_multiple_inputs_default_v_pynative():
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, (x, y), (v, v))
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
primal, grad_fn = vjp(net, x, y)
gradient = grad_fn(v, v)
assert isinstance(gradient, tuple)
assert len(gradient) == 2
assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -93,13 +95,14 @@ def test_vjp_input_function_single_input_single_output_default_v_pynative():
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
def test_function(inputs):
return inputs**3
return inputs ** 3
primal, grad = vjp(test_function, x, v)
primal, grad_fn = vjp(test_function, x)
gradient = grad_fn(v)
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@@ -120,12 +123,46 @@ def test_vjp_construct_single_input_single_output_default_v_pynative():
self.net = network
def construct(self, inputs, vectors):
net_out, vjp_out = vjp(self.net, inputs, vectors)
net_out, grad_fn = vjp(self.net, inputs)
vjp_out = grad_fn(vectors)
return net_out, vjp_out
test_net_pynative = Net(SingleInputNet())
primal, grad = test_net_pynative(x, v)
primal, gradient = test_net_pynative(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
assert np.allclose(gradient[0].asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_outputs_with_has_aux_pynative():
"""
Features: Function vjp
Description: Test vjp with multiple inputs and multiple outputs with has_aux set to True in pynative mode.
Expectation: No exception.
"""
def fn(x, y):
return 2 * x + y, y ** 3
def fn2(*args):
return fn(*args)
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_primal = Tensor(np.array([[3, 6], [9, 12]]).astype(np.float32))
expect_aux = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad_fn, aux = vjp(fn2, x, y, has_aux=True)
gradient = grad_fn(v)
assert isinstance(gradient, tuple)
assert len(gradient) == 2
assert np.allclose(gradient[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(gradient[1].asnumpy(), expect_grad_1.asnumpy())
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(aux.asnumpy(), expect_aux.asnumpy())

View File

@@ -25,12 +25,12 @@ context.set_context(mode=context.GRAPH_MODE)
class SingleInputNet(nn.Cell):
def construct(self, x):
return x**3
return x ** 3
class MultipleInputsOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
return 2 * x, y ** 3
def test_vjp_single_input_graph():
@@ -42,7 +42,7 @@ def test_vjp_single_input_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
vjp(net, x, v)
vjp(net, x)[1](v)
def test_vjp_multiple_inputs_default_v_graph():
@@ -55,7 +55,7 @@ def test_vjp_multiple_inputs_default_v_graph():
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsOutputNet()
vjp(net, (x, y), (v, v))
vjp(net, x, y)[1](v, v)
def test_vjp_wrong_input_type_graph():
@@ -68,4 +68,4 @@ def test_vjp_wrong_input_type_graph():
v = 1
net = SingleInputNet()
with pytest.raises(TypeError):
vjp(net, x, v)
vjp(net, x)[1](v)

View File

@@ -25,12 +25,12 @@ context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputNet(nn.Cell):
def construct(self, x):
return x**3
return x ** 3
class MultipleInputsOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
return 2 * x, y ** 3
def test_vjp_single_input_pynative():
@@ -42,7 +42,7 @@ def test_vjp_single_input_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
vjp(net, x, v)
vjp(net, x)[1](v)
def test_vjp_multiple_inputs_default_v_pynative():
@@ -55,7 +55,7 @@ def test_vjp_multiple_inputs_default_v_pynative():
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsOutputNet()
vjp(net, (x, y), (v, v))
vjp(net, x, y)[1](v, v)
def test_vjp_wrong_input_type_pynative():
@@ -68,4 +68,4 @@ def test_vjp_wrong_input_type_pynative():
v = 1
net = SingleInputNet()
with pytest.raises(TypeError):
vjp(net, x, v)
vjp(net, x)[1](v)