add jacfwd

chenzhuo 2022-08-24 14:49:09 +08:00
parent 987685cb31
commit bd01dcad29
8 changed files with 727 additions and 14 deletions

View File

@ -552,6 +552,7 @@ Parameter operation functions
mindspore.ops.jet
mindspore.ops.jvp
mindspore.ops.vjp
mindspore.ops.jacfwd
mindspore.ops.vmap
Debugging functions

View File

@ -0,0 +1,17 @@
mindspore.ops.jacfwd
====================
.. py:function:: mindspore.ops.jacfwd(fn, grad_position=0, has_aux=False)

    Compute the Jacobian of the given network via forward mode, corresponding to
    `forward-mode automatic differentiation <https://www.mindspore.cn/docs/zh-CN/master/design/auto_gradient.html#前向自动微分>`_.
    When the number of network outputs is much greater than the number of inputs, computing the Jacobian via forward
    mode performs better than via reverse mode.

    Parameters:
        - **fn** (Union[Function, Cell]) - The function or network to be differentiated. It takes Tensors as inputs and returns a Tensor or a tuple of Tensors.
        - **grad_position** (Union[int, tuple[int]]) - Index of the input(s) to differentiate with respect to. An int selects a single input; a tuple selects the inputs at the given indices, with indices starting from 0. Default: 0.
        - **has_aux** (bool) - If True, only the first output of `fn` contributes to the differentiation of `fn`, and the other outputs are returned directly. In this case, `fn` must return more than one output. Default: False.

    Returns:
        Function, used to compute the Jacobian of the given function. For example, for `out1, out2 = fn(*args)`: if `has_aux` is True, the gradient function returns a result of the form `(Jacobian, out2)`, where `out2` does not contribute to the differentiation; if False, it returns `Jacobian` directly.

    Raises:
        - **TypeError** - `grad_position` or `has_aux` is of an unsupported type.
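
    A minimal usage sketch (an editorial addition; it mirrors the Examples section of the English docstring and assumes a build containing this commit)::

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops import jacfwd
        >>> def f(x):
        ...     return x ** 2
        >>> x = Tensor(np.array([1.0, 2.0], np.float32))
        >>> jac = jacfwd(f)(x)    # a diagonal Jacobian with entries 2 * x
        >>> print(jac.shape)
        (2, 2)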

View File

@ -560,6 +560,7 @@ Differential Functions
mindspore.ops.jet
mindspore.ops.jvp
mindspore.ops.vjp
mindspore.ops.jacfwd
mindspore.ops.vmap
Debugging Functions

View File

@ -373,6 +373,7 @@ from .grad import (
derivative,
jvp,
vjp,
jacfwd,
linearize
)
from .debug_func import (

View File

@ -23,6 +23,7 @@ from .grad_func import (
derivative,
jvp,
vjp,
jacfwd,
linearize
)

View File

@ -14,16 +14,16 @@
# ============================================================================
"""Defines gradient related operators with functional form."""
from __future__ import absolute_import
from functools import partial
import numpy as np
from mindspore.common import ms_function
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.grad.cell_grad import _LinearizeInner
from mindspore.ops.primitive import constexpr
from mindspore.ops.function import ones, expand_dims
from mindspore.ops.composite import _Grad, _TaylorOperation, GradOperation
from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose
from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
from mindspore.ops import operations as P
cast = P.Cast()
zeros = P.Zeros()  # zero-filled seeds for inputs outside grad_position (used in _jacfwd_construct_v)
@ -48,6 +48,7 @@ def _raise_type_error():
def _convert_grad_position_type(grad_position):
"""Check and convert the type and size of grad position index."""
if isinstance(grad_position, tuple):
grad_position = tuple(set(grad_position))
for gp in grad_position:
if not isinstance(gp, int):
raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
@ -62,6 +63,16 @@ def _convert_grad_position_type(grad_position):
return grad_position
@constexpr
def _check_grad_position(grad_position, args_num):
"""Check and convert grad position index."""
grad_position = _convert_grad_position_type(grad_position)
for gp in grad_position:
if gp < 0 or gp >= args_num:
raise ValueError("The element in grad_position must belong to [0, args_num).")
return grad_position
@constexpr
def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False):
return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value)
@ -547,8 +558,8 @@ def derivative(fn, primals, order):
return out_primals, out_series
_grad_jvp_single = GradOperation(sens_param=True)
_grad_jvp_all = GradOperation(sens_param=True, get_all=True)
_grad_single = GradOperation(sens_param=True)
_grad_all = GradOperation(sens_param=True, get_all=True)
def jvp(fn, inputs, v, has_aux=False):
@ -626,10 +637,10 @@ def jvp(fn, inputs, v, has_aux=False):
fn_ = fn
def grad_single(u, first_grad_single_value):
return _grad_jvp_single(fn_)(*first_grad_single_value, u)
return _grad_single(fn_)(*first_grad_single_value, u)
def grad_all(u, first_grad):
return _grad_jvp_all(fn_)(*first_grad, u)
return _grad_all(fn_)(*first_grad, u)
@ms_function(hash_args=fn_)
def _wrap_container(*arg):
@ -643,10 +654,10 @@ def jvp(fn, inputs, v, has_aux=False):
else:
u = oneslike(outputs)
if len(jvp_inputs) == 1:
second_grad_net = _grad_jvp_single(grad_single)
second_grad_net = _grad_single(grad_single)
gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
else:
second_grad_net = _grad_jvp_single(grad_all)
second_grad_net = _grad_single(grad_all)
gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
if has_aux:
res = fn(*jvp_inputs)
@ -741,9 +752,6 @@ def _check_tensor(inputs):
return True
vjp_grad = GradOperation(get_all=True, sens_param=True)
def vjp(fn, *inputs, has_aux=False):
"""
Compute the vector-jacobian-product of the given network. `vjp` matches
@ -830,8 +838,8 @@ def vjp(fn, *inputs, has_aux=False):
def wrap_container(*v):
_check_tensor(v)
if len(v) == 1:
return vjp_grad(fn_)(*inputs, v[0])
return vjp_grad(fn_)(*inputs, v)
return _grad_all(fn_)(*inputs, v[0])
return _grad_all(fn_)(*inputs, v)
res = fn(*inputs)
if has_aux:
@ -841,6 +849,201 @@ def vjp(fn, *inputs, has_aux=False):
return res, wrap_container
@constexpr
def _jac_generate_target_dimension(x):
"""For given length = len(x), this method generates target dimension tuple (1, 2, 3,..., length, 0)."""
target_dimension = tuple(index + 1 for index, _ in enumerate(x[1:])) + (0,)
return target_dimension
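# Editorial sketch, not part of this commit: for a vmapped block of shape
# (N, a, b), the generated permutation is (1, 2, 0); transposing by it moves
# the leading batch axis N to the end, matching np.moveaxis(block, 0, -1).
_demo_block = np.arange(24).reshape(4, 2, 3)  # N = 4, item shape (2, 3)
_demo_perm = tuple(i + 1 for i, _ in enumerate(_demo_block.shape[1:])) + (0,)  # (1, 2, 0)
assert np.array_equal(_demo_block.transpose(_demo_perm), np.moveaxis(_demo_block, 0, -1))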
def _jac_trans_item(item, inputs_shape, grad_position):
"""transfer origin item to derivative of each output with respect to each input."""
output_wrt_input_all = ()
length = len(inputs_shape) - 1
for i in range(length):
if i in grad_position:
origin_output_wrt_input = item[inputs_shape[i][1]:inputs_shape[i + 1][1]]
target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
temp = transpose(origin_output_wrt_input, target_dimension)
output_wrt_input = reshape(temp, temp.shape[:-1] + inputs_shape[i + 1][0])
output_wrt_input_all += (output_wrt_input,)
return output_wrt_input_all
def _jacfwd_postprocess(x, inputs_shape, grad_position):
"""reformat jacobian."""
if isinstance(x, tuple):
jacobian = ()
for item in x:
jacobian += _jac_trans_item(item, inputs_shape, grad_position)
res = jacobian
else:
res = _jac_trans_item(x, inputs_shape, grad_position)
if len(res) == 1:
return res[0]
input_num = len(grad_position)
if len(res) % input_num != 0:
raise ValueError("The numbers of inputs and outputs do not match.")
output_num = len(res) // input_num
if input_num == 1 or output_num == 1:
return res
jac = ()
for i in range(output_num):
input_grad = ()
for j in range(input_num):
input_grad += (res[i * input_num + j],)
jac += (input_grad,)
return jac
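# Editorial sketch, not part of this commit: the flat result holds one block
# per (output, input) pair in output-major order; the loop above regroups it
# so that jac[i][j] is the derivative of output i with respect to input j.
_demo_res = ("dO0/dI0", "dO0/dI1", "dO1/dI0", "dO1/dI1")
_demo_jac = tuple(tuple(_demo_res[i * 2 + j] for j in range(2)) for i in range(2))
assert _demo_jac[1][0] == "dO1/dI0"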
def _jacfwd_construct_v(inputs, grad_position):
"""
For input (x, y), x.shape = (a, b), y.shape = (c, d), this method generates corresponding v (v1, v2),
v1.shape = (N, a, b), v2.shape = (N, c, d), where N = a*b + c*d.
"""
v = ()
primals = ()
inputs_shape = (((), 0),)
num = 0
items_num = ()
cum_num = (0,)
for item in inputs:
item_num = size(item)
num += item_num
inputs_shape += ((item.shape, num),)
items_num += (item_num,)
cum_num += (num,)
for i, element in enumerate(inputs):
item_size = items_num[i]
if i in grad_position:
temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
else:
temp2 = zeros((num, item_size), mstype.float32)
input_v = reshape(temp2, (num,) + element.shape)
primal = broadcast_to(element, (num,) + element.shape)
v += (input_v,)
primals += (primal,)
if len(inputs) == 1:
return primals, v[0], inputs_shape
return primals, v, inputs_shape
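# Editorial sketch, not part of this commit: the numpy equivalent of the
# seeding above for inputs x (2, 3) and y (1, 4) with grad_position = (0,).
# Input x receives rows of the N x N identity; input y receives zero rows.
_demo_x, _demo_y = np.ones((2, 3), np.float32), np.ones((1, 4), np.float32)
_demo_n = _demo_x.size + _demo_y.size  # N = 10
_demo_vx = np.eye(_demo_n, _demo_x.size, 0, np.float32).reshape((_demo_n,) + _demo_x.shape)
_demo_vy = np.zeros((_demo_n, _demo_y.size), np.float32).reshape((_demo_n,) + _demo_y.shape)
_demo_primal_x = np.broadcast_to(_demo_x, (_demo_n,) + _demo_x.shape)  # primal tiled along the batch axis
assert _demo_vx.shape == (10, 2, 3) and _demo_vy.shape == (10, 1, 4)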
_vmap = _Vmap()
def jacfwd(fn, grad_position=0, has_aux=False):
"""
Compute Jacobian via forward mode, corresponding to
`forward-mode differentiation <https://www.mindspore.cn/docs/en/master/design/auto_gradient.html#forward-mode-ad>`_.
When the number of outputs is much greater than the number of inputs, computing the Jacobian via forward mode is
faster than via reverse mode.
Args:
fn (Union(Cell, function)): The function or Cell to be differentiated.
grad_position (Union(int, tuple[int])): If int, get the gradient with respect to a single input.
If tuple, get the gradients with respect to the selected inputs. Indices in `grad_position` begin at 0. Default: 0.
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other
outputs will be returned directly. It means `fn` must return more than one output in this case.
Default: False.
Returns:
Function, returns the Jacobian function for the input function or cell.
For example, as for `out1, out2 = fn(*args)`, when `has_aux` is True, the gradient function will return outputs
of the form `(Jacobian, out2)`, where `out2` does not contribute to the differentiation; otherwise it returns `Jacobian` only.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.ops import jacfwd
>>> from mindspore import Tensor
>>> class MultipleInputsMultipleOutputsNet(nn.Cell):
... def construct(self, x, y, z):
... return x ** 2 + y ** 2 + z ** 2, x * y * z
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> net = MultipleInputsMultipleOutputsNet()
>>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
>>> print(jac)
Tensor(shape=[2, 2, 2, 2], dtype=Float32, value=
[[[[ 2.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00]],
[[ 0.00000000e+00, 4.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00]]],
[[[ 0.00000000e+00, 0.00000000e+00],
[ 6.00000000e+00, 0.00000000e+00]],
[[ 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 8.00000000e+00]]]])
>>> print(aux)
[[ 1. 4.]
[ 9. 16.]]
"""
_check_has_aux_type(has_aux)
def aux_fn(*args):
outputs = fn(*args)
if not isinstance(outputs, tuple) or len(outputs) < 2:
raise ValueError("When 'has_aux' is True, 'fn' must return more than one output.")
res = outputs[0]
return res
def grad_single(u, first_grad_single_value):
if has_aux:
return _grad_single(aux_fn)(*first_grad_single_value, u)
return _grad_single(fn)(*first_grad_single_value, u)
def grad_all(u, first_grad):
if has_aux:
return _grad_all(aux_fn)(*first_grad, u)
return _grad_all(fn)(*first_grad, u)
@ms_function
def wrapped(*args):
checked_grad_position = _check_grad_position(grad_position, len(args))
primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
def inner_fn(jvp_inputs, vectors):
outputs = fn(*jvp_inputs)
if isinstance(outputs, tuple):
u = ()
for item in outputs:
u = u + (oneslike(item),)
else:
u = oneslike(outputs)
if len(jvp_inputs) == 1:
second_grad_net = _grad_single(grad_single)
else:
second_grad_net = _grad_single(grad_all)
gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
return gradient_outputs
def inner_aux_fn(jvp_inputs, vectors):
outputs = aux_fn(*jvp_inputs)
u = oneslike(outputs)
if len(jvp_inputs) == 1:
second_grad_net = _grad_single(grad_single)
else:
second_grad_net = _grad_single(grad_all)
gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
return gradient_outputs
if has_aux:
res = _vmap(inner_aux_fn)(primals, v)
jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
forward_outputs = fn(*args)
if len(forward_outputs) == 2:
return jac_res, forward_outputs[1]
return jac_res, forward_outputs[1:]
res = _vmap(inner_fn)(primals, v)
jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
return jac_res
return wrapped
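# Editorial sketch, not part of this commit: the whole pipeline above in plain
# numpy, with a central finite difference standing in for the true forward-mode
# JVP. Seed an identity basis over the flattened input, push each basis vector
# through the (approximate) JVP, then move the batch axis to the end and
# unflatten it into the input shape, as _jacfwd_postprocess does.
def _demo_jacfwd(f, x, eps=1e-3):
    n = x.size
    seeds = np.eye(n, dtype=x.dtype).reshape((n,) + x.shape)  # one-hot tangent per input element
    cols = [(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in seeds]
    stacked = np.stack(cols)  # shape (N,) + out_shape, batch axis leading
    return np.moveaxis(stacked, 0, -1).reshape(f(x).shape + x.shape)
_demo_jac = _demo_jacfwd(lambda t: t ** 3, np.array([[1., 2.], [3., 4.]], np.float32))
assert np.allclose(_demo_jac[0, 0], [[3., 0.], [0., 0.]], atol=1e-2)  # d(x**3)/dx = 3 * x**2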
__all__ = [
'grad',
'value_and_grad',
@ -848,6 +1051,7 @@ __all__ = [
'derivative',
'jvp',
'vjp',
'jacfwd',
'linearize'
]
__all__.sort()

View File

@ -0,0 +1,244 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test function jacfwd in graph mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops import jacfwd
context.set_context(mode=context.GRAPH_MODE)
class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x):
return x ** 3
class SingleInputMultipleOutputsNet(nn.Cell):
def construct(self, x):
return x ** 3, 2 * x
class MultipleInputsSingleOutputNet(nn.Cell):
def construct(self, x, y, z):
return x * y * z
class MultipleInputsMultipleOutputsNet(nn.Cell):
def construct(self, x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def function(x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def iteration_jac_function(x, y, z):
return x ** 2 * y * z
@ms_function
def jac_wrap_with_ms_function(x, y, z):
output = jacfwd(function, has_aux=True)(x, y, z)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_single_output_cell_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with single input and single output net in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
[[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
jac = jacfwd(net)(x)
assert np.allclose(jac.asnumpy(), expect_jac)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_multiple_outputs_cell_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with single input and multiple outputs net in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputsNet()
expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
[[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
[[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
jac = jacfwd(net)(x)
assert np.allclose(jac[0].asnumpy(), expect_jac_0)
assert np.allclose(jac[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_single_output_cell_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with multiple inputs and single output net in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsSingleOutputNet()
expect_jac_0 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
assert isinstance(jac, tuple)
assert len(jac) == 2
assert np.allclose(jac[0].asnumpy(), expect_jac_0)
assert np.allclose(jac[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_multiple_outputs_cell_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with multiple inputs and multiple outputs net in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsMultipleOutputsNet()
expect_jac_0 = np.array([[[[-4, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-2, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [10, 0]], [[0, 0], [0, -2]]]]).astype(np.float32)
expect_jac_2 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
expect_jac_3 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
assert isinstance(jac, tuple)
assert len(jac) == 2
assert np.allclose(jac[0][0].asnumpy(), expect_jac_0)
assert np.allclose(jac[0][1].asnumpy(), expect_jac_1)
assert np.allclose(jac[1][0].asnumpy(), expect_jac_2)
assert np.allclose(jac[1][1].asnumpy(), expect_jac_3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_wrap_with_ms_function_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd wrapped with ms_function in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
jac, aux = jac_wrap_with_ms_function(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_grad_position_twice_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with grad_position set twice on the same Cell in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
[[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
[[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
net = MultipleInputsSingleOutputNet()
jac1 = jacfwd(net, grad_position=0)(x, y, z)
jac2 = jacfwd(net, grad_position=(0, 1))(x, y, z)
assert np.allclose(jac1.asnumpy(), expect_jac_0)
assert np.allclose(jac2[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_has_aux_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with a Cell and has_aux set in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
net = MultipleInputsMultipleOutputsNet()
jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_function_has_aux_graph():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with a function and has_aux set in graph mode.
Expectation: No exception.
"""
def fn(x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def fn2(*args):
x = args[0]
y = args[1]
z = args[2]
return fn(x, y, z)
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
jac, aux = jacfwd(fn2, grad_position=0, has_aux=True)(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)

View File

@ -0,0 +1,244 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test function jacfwd in pynative mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops import jacfwd
context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x):
return x ** 3
class SingleInputMultipleOutputsNet(nn.Cell):
def construct(self, x):
return x ** 3, 2 * x
class MultipleInputsSingleOutputNet(nn.Cell):
def construct(self, x, y, z):
return x * y * z
class MultipleInputsMultipleOutputsNet(nn.Cell):
def construct(self, x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def function(x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def iteration_jac_function(x, y, z):
return x ** 2 * y * z
@ms_function
def jac_wrap_with_ms_function(x, y, z):
output = jacfwd(function, has_aux=True)(x, y, z)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_single_output_cell_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with single input and single output net in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
[[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
jac = jacfwd(net)(x)
assert np.allclose(jac.asnumpy(), expect_jac)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_multiple_outputs_cell_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with single input and multiple outputs net in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputsNet()
expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
[[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
[[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
jac = jacfwd(net)(x)
assert np.allclose(jac[0].asnumpy(), expect_jac_0)
assert np.allclose(jac[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_single_output_cell_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with multiple inputs and single output net in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsSingleOutputNet()
expect_jac_0 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
assert isinstance(jac, tuple)
assert len(jac) == 2
assert np.allclose(jac[0].asnumpy(), expect_jac_0)
assert np.allclose(jac[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_multiple_outputs_cell_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with multiple inputs and multiple outputs net in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsMultipleOutputsNet()
expect_jac_0 = np.array([[[[-4, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-2, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [10, 0]], [[0, 0], [0, -2]]]]).astype(np.float32)
expect_jac_2 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
expect_jac_3 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
[[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
assert isinstance(jac, tuple)
assert len(jac) == 2
assert np.allclose(jac[0][0].asnumpy(), expect_jac_0)
assert np.allclose(jac[0][1].asnumpy(), expect_jac_1)
assert np.allclose(jac[1][0].asnumpy(), expect_jac_2)
assert np.allclose(jac[1][1].asnumpy(), expect_jac_3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_wrap_with_ms_function_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd wrapped with ms_function in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
jac, aux = jac_wrap_with_ms_function(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_grad_position_twice_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with grad_position set twice on the same Cell in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
[[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
[[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
net = MultipleInputsSingleOutputNet()
jac1 = jacfwd(net, grad_position=0)(x, y, z)
jac2 = jacfwd(net, grad_position=(0, 1))(x, y, z)
assert np.allclose(jac1.asnumpy(), expect_jac_0)
assert np.allclose(jac2[1].asnumpy(), expect_jac_1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_has_aux_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with a Cell and has_aux set in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
net = MultipleInputsMultipleOutputsNet()
jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_function_has_aux_pynative():
"""
Features: Function jacfwd.
Description: Test ops.jacfwd with a function and has_aux set in pynative mode.
Expectation: No exception.
"""
def fn(x, y, z):
return x ** 2 + y ** 2 + z ** 2, x * y * z
def fn2(*args):
x = args[0]
y = args[1]
z = args[2]
return fn(x, y, z)
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
[[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
jac, aux = jacfwd(fn2, grad_position=0, has_aux=True)(x, y, z)
assert np.allclose(jac.asnumpy(), expect_jac)
assert np.allclose(aux.asnumpy(), expect_aux)