forked from mindspore-Ecosystem/mindspore

add jacfwd

commit bd01dcad29 (parent 987685cb31)
@@ -552,6 +552,7 @@ Parameter operation functions

    mindspore.ops.jet
    mindspore.ops.jvp
    mindspore.ops.vjp
    mindspore.ops.jacfwd
    mindspore.ops.vmap

Debugging functions
@@ -0,0 +1,17 @@

mindspore.ops.jacfwd
====================

.. py:function:: mindspore.ops.jacfwd(fn, grad_position=0, has_aux=False)

    Compute the Jacobian of the given network via forward-mode differentiation, corresponding to
    `forward-mode automatic differentiation <https://www.mindspore.cn/docs/zh-CN/master/design/auto_gradient.html#前向自动微分>`_.
    When the number of network outputs is much greater than the number of inputs, forward mode computes
    the Jacobian with better performance than reverse mode.

    Parameters:
        - **fn** (Union[Function, Cell]) - The function or network to be differentiated. It takes Tensors
          as inputs and returns a Tensor or a tuple of Tensors.
        - **grad_position** (Union[NoneType, int, tuple[int]]) - Index of the input position(s) to
          differentiate with respect to. An int selects a single input; a tuple selects the inputs at the
          given indices, where indices start from 0. Default: 0.
        - **has_aux** (bool) - If `has_aux` is True, only the first output of `fn` contributes to the
          differentiation of `fn`, and the other outputs are returned directly. In that case, `fn` must
          return more than one output. Default: False.

    Returns:
        Function, used to compute the Jacobian of the given function. For example, for `out1, out2 = fn(*args)`,
        if `has_aux` is True, the gradient function returns a result of the form `(Jacobian, out2)`, where
        `out2` does not contribute to the differentiation; if False, it returns `Jacobian` directly.

    Raises:
        - **TypeError** - `grad_position` or `has_aux` does not meet the type requirements.
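A quick aside on the layout documented above: the returned Jacobian stacks the output shape in
front of the input shape. A minimal NumPy sketch of that convention (illustrative only, mirroring
the tests further below) for f(x) = x ** 3 on a 2x2 input::

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
    # d(x_ij ** 3) / d(x_kl) = 3 * x_ij ** 2 when (i, j) == (k, l), else 0,
    # so the Jacobian has shape out_shape + in_shape = (2, 2, 2, 2).
    jac = np.zeros((2, 2, 2, 2), dtype=np.float32)
    for i in range(2):
        for j in range(2):
            jac[i, j, i, j] = 3.0 * x[i, j] ** 2
    print(jac[1, 1])  # [[ 0.  0.]
                      #  [ 0. 48.]]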
@@ -560,6 +560,7 @@ Differential Functions

    mindspore.ops.jet
    mindspore.ops.jvp
    mindspore.ops.vjp
    mindspore.ops.jacfwd
    mindspore.ops.vmap

Debugging Functions
@@ -373,6 +373,7 @@ from .grad import (
    derivative,
    jvp,
    vjp,
    jacfwd,
    linearize
)
from .debug_func import (
@@ -23,6 +23,7 @@ from .grad_func import (
    derivative,
    jvp,
    vjp,
    jacfwd,
    linearize
)
@@ -14,16 +14,16 @@
# ============================================================================

"""Defines gradient related operators with functional form."""

from __future__ import absolute_import
from functools import partial
import numpy as np
from mindspore.common import ms_function
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.grad.cell_grad import _LinearizeInner
from mindspore.ops.primitive import constexpr
-from mindspore.ops.function import ones, expand_dims
-from mindspore.ops.composite import _Grad, _TaylorOperation, GradOperation
+from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose
+from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
from mindspore.ops import operations as P

cast = P.Cast()
@@ -48,6 +48,7 @@ def _raise_type_error():
def _convert_grad_position_type(grad_position):
    """Check and convert the type and size of grad position index."""
    if isinstance(grad_position, tuple):
        grad_position = tuple(set(grad_position))
        for gp in grad_position:
            if not isinstance(gp, int):
                raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
@@ -62,6 +63,16 @@ def _convert_grad_position_type(grad_position):
    return grad_position


@constexpr
def _check_grad_position(grad_position, args_num):
    """Check and convert grad position index."""
    grad_position = _convert_grad_position_type(grad_position)
    for gp in grad_position:
        if gp < 0 or gp >= args_num:
            raise ValueError("The element in grad_position must belong to [0, args_num).")
    return grad_position


@constexpr
def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False):
    return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value)
@@ -547,8 +558,8 @@ def derivative(fn, primals, order):
    return out_primals, out_series


-_grad_jvp_single = GradOperation(sens_param=True)
-_grad_jvp_all = GradOperation(sens_param=True, get_all=True)
+_grad_single = GradOperation(sens_param=True)
+_grad_all = GradOperation(sens_param=True, get_all=True)
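# Note on the renamed operators above: GradOperation(sens_param=True) builds a
# gradient function that takes an extra sensitivity value (the cotangent u)
# appended to the inputs, i.e. it computes a vector-Jacobian product rather than
# a plain gradient; get_all=True returns the gradients with respect to every
# positional input instead of only the first one.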


def jvp(fn, inputs, v, has_aux=False):
@@ -626,10 +637,10 @@ def jvp(fn, inputs, v, has_aux=False):
        fn_ = fn

    def grad_single(u, first_grad_single_value):
-        return _grad_jvp_single(fn_)(*first_grad_single_value, u)
+        return _grad_single(fn_)(*first_grad_single_value, u)

    def grad_all(u, first_grad):
-        return _grad_jvp_all(fn_)(*first_grad, u)
+        return _grad_all(fn_)(*first_grad, u)

    @ms_function(hash_args=fn_)
    def _wrap_container(*arg):
@@ -643,10 +654,10 @@ def jvp(fn, inputs, v, has_aux=False):
        else:
            u = oneslike(outputs)
        if len(jvp_inputs) == 1:
-            second_grad_net = _grad_jvp_single(grad_single)
+            second_grad_net = _grad_single(grad_single)
            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
        else:
-            second_grad_net = _grad_jvp_single(grad_all)
+            second_grad_net = _grad_single(grad_all)
            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
        if has_aux:
            res = fn(*jvp_inputs)
@@ -741,9 +752,6 @@ def _check_tensor(inputs):
    return True


-vjp_grad = GradOperation(get_all=True, sens_param=True)
-
-
def vjp(fn, *inputs, has_aux=False):
    """
    Compute the vector-jacobian-product of the given network. `vjp` matches
@@ -830,8 +838,8 @@ def vjp(fn, *inputs, has_aux=False):
    def wrap_container(*v):
        _check_tensor(v)
        if len(v) == 1:
-            return vjp_grad(fn_)(*inputs, v[0])
-        return vjp_grad(fn_)(*inputs, v)
+            return _grad_all(fn_)(*inputs, v[0])
+        return _grad_all(fn_)(*inputs, v)

    res = fn(*inputs)
    if has_aux:
@@ -841,6 +849,201 @@ def vjp(fn, *inputs, has_aux=False):
    return res, wrap_container


@constexpr
def _jac_generate_target_dimension(x):
    """For given length = len(x), this method generates the target dimension tuple (1, 2, ..., length - 1, 0)."""
    target_dimension = tuple(index + 1 for index, _ in enumerate(x[1:])) + (0,)
    return target_dimension
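# Sketch of what the permutation above does (hypothetical values): for a
# vmap-stacked result of shape x = (6, 2, 3), the generated tuple is (1, 2, 0),
# so transpose moves the stacked axis of size 6 (one slot per basis tangent)
# from the front to the back: (6, 2, 3) -> (2, 3, 6).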


def _jac_trans_item(item, inputs_shape, grad_position):
    """Transfer an origin item to the derivative of each output with respect to each input."""
    output_wrt_input_all = ()
    length = len(inputs_shape) - 1
    for i in range(length):
        if i in grad_position:
            origin_output_wrt_input = item[inputs_shape[i][1]:inputs_shape[i + 1][1]]
            target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
            temp = transpose(origin_output_wrt_input, target_dimension)
            output_wrt_input = reshape(temp, temp.shape[:-1] + inputs_shape[i + 1][0])
            output_wrt_input_all += (output_wrt_input,)
    return output_wrt_input_all
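# Sketch of the reshaping above (hypothetical shapes): with inputs x of shape
# (2, 2) and y of shape (3,), there are N = 7 basis tangents, and an output of
# shape (2, 2) yields a vmapped item of shape (7, 2, 2). For input x the rows
# 0:4 are sliced out -> (4, 2, 2), transposed to (2, 2, 4), then reshaped to
# out_shape + in_shape = (2, 2, 2, 2), which is the per-pair Jacobian layout
# the tests check.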


def _jacfwd_postprocess(x, inputs_shape, grad_position):
    """Reformat the Jacobian."""
    if isinstance(x, tuple):
        jacobian = ()
        for item in x:
            jacobian += _jac_trans_item(item, inputs_shape, grad_position)
        res = jacobian
    else:
        res = _jac_trans_item(x, inputs_shape, grad_position)
    if len(res) == 1:
        return res[0]
    input_num = len(grad_position)
    if len(res) % input_num != 0:
        raise ValueError("The numbers of inputs and outputs do not match.")
    output_num = len(res) // input_num
    if input_num == 1 or output_num == 1:
        return res
    jac = ()
    for i in range(output_num):
        input_grad = ()
        for j in range(input_num):
            input_grad += (res[i * input_num + j],)
        jac += (input_grad,)
    return jac
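# Nesting sketch (hypothetical case): with 2 outputs and 2 differentiated
# inputs, the flat tuple res = (J00, J01, J10, J11) is regrouped so that
# jac[i][j] is the Jacobian of output i with respect to input j:
#   jac = ((J00, J01), (J10, J11))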


def _jacfwd_construct_v(inputs, grad_position):
    """
    For input (x, y), x.shape = (a, b), y.shape = (c, d), this method generates the corresponding v (v1, v2),
    v1.shape = (N, a, b), v2.shape = (N, c, d), where N = a*b + c*d.
    """
    v = ()
    primals = ()
    inputs_shape = (((), 0),)
    num = 0
    items_num = ()
    cum_num = (0,)
    for item in inputs:
        item_num = size(item)
        num += item_num
        inputs_shape += ((item.shape, num),)
        items_num += (item_num,)
        cum_num += (num,)
    for i, element in enumerate(inputs):
        item_size = items_num[i]
        if i in grad_position:
            temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
        else:
            temp2 = zeros((num, item_size), mstype.float32)
        input_v = reshape(temp2, (num,) + element.shape)
        primal = broadcast_to(element, (num,) + element.shape)
        v += (input_v,)
        primals += (primal,)
    if len(inputs) == 1:
        return primals, v[0], inputs_shape
    return primals, v, inputs_shape
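# Basis-tangent sketch (hypothetical shapes): for x.shape = (2,), y.shape = (2,)
# and grad_position = (0, 1), N = 4 and the shifted identities are
#   np.eye(4, 2,  0) -> [[1, 0], [0, 1], [0, 0], [0, 0]]   (tangents for x)
#   np.eye(4, 2, -2) -> [[0, 0], [0, 0], [1, 0], [0, 1]]   (tangents for y)
# so row k of the stacked v is the k-th standard basis vector of the flattened
# inputs, and vmapping a JVP over axis 0 evaluates one Jacobian column per row.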


_vmap = _Vmap()


def jacfwd(fn, grad_position=0, has_aux=False):
    """
    Compute the Jacobian via forward mode, corresponding to
    `forward-mode differentiation <https://www.mindspore.cn/docs/en/master/design/auto_gradient.html#forward-mode-ad>`_.
    When the number of outputs is much greater than the number of inputs, computing the Jacobian via
    forward mode performs better than reverse mode.

    Args:
        fn (Union(Cell, function)): The function or Cell to be differentiated.
        grad_position (Union(int, tuple[int])): If int, get the gradient with respect to a single input.
            If tuple, get the gradients with respect to the selected inputs. `grad_position` begins with 0.
            Default: 0.
        has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the
            other outputs will be returned directly. It means `fn` must return more than one output in this case.
            Default: False.

    Returns:
        Function, the Jacobian function for the input function or cell.
        For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set to True, the gradient function
        will return outputs like `(Jacobian, out2)`, where `out2` does not contribute to the differentiation;
        otherwise it will return `Jacobian` only.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore.ops import jacfwd
        >>> from mindspore import Tensor
        >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x ** 2 + y ** 2 + z ** 2, x * y * z
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
        >>> net = MultipleInputsMultipleOutputsNet()
        >>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
        >>> print(jac)
        Tensor(shape=[2, 2, 2, 2], dtype=Float32, value=
        [[[[ 2.00000000e+00, 0.00000000e+00],
           [ 0.00000000e+00, 0.00000000e+00]],
          [[ 0.00000000e+00, 4.00000000e+00],
           [ 0.00000000e+00, 0.00000000e+00]]],
         [[[ 0.00000000e+00, 0.00000000e+00],
           [ 6.00000000e+00, 0.00000000e+00]],
          [[ 0.00000000e+00, 0.00000000e+00],
           [ 0.00000000e+00, 8.00000000e+00]]]])
        >>> print(aux)
        [[ 1.  4.]
         [ 9. 16.]]
    """
    _check_has_aux_type(has_aux)

    def aux_fn(*args):
        outputs = fn(*args)
        if not isinstance(outputs, tuple) or len(outputs) < 2:
            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one output.")
        res = outputs[0]
        return res

    def grad_single(u, first_grad_single_value):
        if has_aux:
            return _grad_single(aux_fn)(*first_grad_single_value, u)
        return _grad_single(fn)(*first_grad_single_value, u)

    def grad_all(u, first_grad):
        if has_aux:
            return _grad_all(aux_fn)(*first_grad, u)
        return _grad_all(fn)(*first_grad, u)

    @ms_function
    def wrapped(*args):
        checked_grad_position = _check_grad_position(grad_position, len(args))
        primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)

        def inner_fn(jvp_inputs, vectors):
            outputs = fn(*jvp_inputs)
            if isinstance(outputs, tuple):
                u = ()
                for item in outputs:
                    u = u + (oneslike(item),)
            else:
                u = oneslike(outputs)
            if len(jvp_inputs) == 1:
                second_grad_net = _grad_single(grad_single)
            else:
                second_grad_net = _grad_single(grad_all)
            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
            return gradient_outputs

        def inner_aux_fn(jvp_inputs, vectors):
            outputs = aux_fn(*jvp_inputs)
            u = oneslike(outputs)
            if len(jvp_inputs) == 1:
                second_grad_net = _grad_single(grad_single)
            else:
                second_grad_net = _grad_single(grad_all)
            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
            return gradient_outputs

        if has_aux:
            res = _vmap(inner_aux_fn)(primals, v)
            jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
            forward_outputs = fn(*args)
            if len(forward_outputs) == 2:
                return jac_res, forward_outputs[1]
            return jac_res, forward_outputs[1:]
        res = _vmap(inner_fn)(primals, v)
        jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
        return jac_res

    return wrapped
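# Summary of the mechanism above: each basis-tangent row produced by
# _jacfwd_construct_v is pushed through a JVP built by double differentiation
# (grad-of-grad with a sensitivity argument), _vmap batches all N such JVPs in
# one pass, and _jacfwd_postprocess reshapes the stacked rows into
# out_shape + in_shape Jacobian blocks.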
__all__ = [
    'grad',
    'value_and_grad',
@@ -848,6 +1051,7 @@ __all__ = [
    'derivative',
    'jvp',
    'vjp',
    'jacfwd',
    'linearize'
]
__all__.sort()
@@ -0,0 +1,244 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test function jacfwd in graph mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops import jacfwd

context.set_context(mode=context.GRAPH_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x ** 3


class SingleInputMultipleOutputsNet(nn.Cell):
    def construct(self, x):
        return x ** 3, 2 * x


class MultipleInputsSingleOutputNet(nn.Cell):
    def construct(self, x, y, z):
        return x * y * z


class MultipleInputsMultipleOutputsNet(nn.Cell):
    def construct(self, x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z


def function(x, y, z):
    return x ** 2 + y ** 2 + z ** 2, x * y * z


def iteration_jac_function(x, y, z):
    return x ** 2 * y * z


@ms_function
def jac_wrap_with_ms_function(x, y, z):
    output = jacfwd(function, has_aux=True)(x, y, z)
    return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_single_output_cell_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with single input and single output net in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                           [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    jac = jacfwd(net)(x)
    assert np.allclose(jac.asnumpy(), expect_jac)
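# Where the expected values above come from (worked arithmetic): for
# f(x) = x ** 3, d f_ij / d x_kl = 3 * x_ij ** 2 when (i, j) == (k, l) and 0
# otherwise, so the nonzero entries are 3 * 1, 3 * 4, 3 * 9, 3 * 16
# = 3, 12, 27, 48 at the diagonal positions of the (2, 2, 2, 2) Jacobian.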


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_multiple_outputs_cell_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with single input and multiple outputs net in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputsNet()
    expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                             [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
    jac = jacfwd(net)(x)
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_single_output_cell_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with multiple inputs and single output net in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsSingleOutputNet()
    expect_jac_0 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_multiple_outputs_cell_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with multiple inputs and multiple outputs net in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsMultipleOutputsNet()
    expect_jac_0 = np.array([[[[-4, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-2, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [10, 0]], [[0, 0], [0, -2]]]]).astype(np.float32)
    expect_jac_2 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_3 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0][0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[0][1].asnumpy(), expect_jac_1)
    assert np.allclose(jac[1][0].asnumpy(), expect_jac_2)
    assert np.allclose(jac[1][1].asnumpy(), expect_jac_3)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_wrap_with_ms_function_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd wrapped with ms_function in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
    jac, aux = jac_wrap_with_ms_function(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_grad_position_twice_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with function setting grad_position twice in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
                             [[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    net = MultipleInputsSingleOutputNet()
    jac1 = jacfwd(net, grad_position=0)(x, y, z)
    jac2 = jacfwd(net, grad_position=(0, 1))(x, y, z)
    assert np.allclose(jac1.asnumpy(), expect_jac_0)
    assert np.allclose(jac2[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_has_aux_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with Cell setting has_aux in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    net = MultipleInputsMultipleOutputsNet()
    jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_function_has_aux_graph():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with function setting has_aux in graph mode.
    Expectation: No exception.
    """

    def fn(x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z

    def fn2(*args):
        x = args[0]
        y = args[1]
        z = args[2]
        return fn(x, y, z)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    jac, aux = jacfwd(fn2, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)
@@ -0,0 +1,244 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test function jacfwd in pynative mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops import jacfwd

context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x ** 3


class SingleInputMultipleOutputsNet(nn.Cell):
    def construct(self, x):
        return x ** 3, 2 * x


class MultipleInputsSingleOutputNet(nn.Cell):
    def construct(self, x, y, z):
        return x * y * z


class MultipleInputsMultipleOutputsNet(nn.Cell):
    def construct(self, x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z


def function(x, y, z):
    return x ** 2 + y ** 2 + z ** 2, x * y * z


def iteration_jac_function(x, y, z):
    return x ** 2 * y * z


@ms_function
def jac_wrap_with_ms_function(x, y, z):
    output = jacfwd(function, has_aux=True)(x, y, z)
    return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_single_output_cell_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with single input and single output net in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                           [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    jac = jacfwd(net)(x)
    assert np.allclose(jac.asnumpy(), expect_jac)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_single_input_multiple_outputs_cell_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with single input and multiple outputs net in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputsNet()
    expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                             [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
    jac = jacfwd(net)(x)
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_single_output_cell_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with multiple inputs and single output net in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsSingleOutputNet()
    expect_jac_0 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_multiple_inputs_multiple_outputs_cell_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with multiple inputs and multiple outputs net in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsMultipleOutputsNet()
    expect_jac_0 = np.array([[[[-4, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-2, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [10, 0]], [[0, 0], [0, -2]]]]).astype(np.float32)
    expect_jac_2 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_3 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacfwd(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0][0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[0][1].asnumpy(), expect_jac_1)
    assert np.allclose(jac[1][0].asnumpy(), expect_jac_2)
    assert np.allclose(jac[1][1].asnumpy(), expect_jac_3)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_wrap_with_ms_function_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd wrapped with ms_function in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
    jac, aux = jac_wrap_with_ms_function(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_grad_position_twice_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with function setting grad_position twice in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
                             [[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    net = MultipleInputsSingleOutputNet()
    jac1 = jacfwd(net, grad_position=0)(x, y, z)
    jac2 = jacfwd(net, grad_position=(0, 1))(x, y, z)
    assert np.allclose(jac1.asnumpy(), expect_jac_0)
    assert np.allclose(jac2[1].asnumpy(), expect_jac_1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_has_aux_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with Cell setting has_aux in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    net = MultipleInputsMultipleOutputsNet()
    jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jac_with_function_has_aux_pynative():
    """
    Features: Function jacfwd.
    Description: Test ops.jacfwd with function setting has_aux in pynative mode.
    Expectation: No exception.
    """

    def fn(x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z

    def fn2(*args):
        x = args[0]
        y = args[1]
        z = args[2]
        return fn(x, y, z)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    jac, aux = jacfwd(fn2, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)