!40032 Add linearize ms_function test case

Merge pull request !40032 from ZhengXuegui/test_linearize
This commit is contained in:
i-robot 2022-08-16 19:27:02 +00:00 committed by Gitee
commit ba031c7277
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 65 additions and 1 deletions

View File

@ -646,6 +646,7 @@ def linearize(fn, inputs):
[[12. 15.]
[16. 19.]]
"""
linearize_inner = _LinearizeInner()
@ms_function(hash_args=fn)
def _wrap_container(*arg):
@ -656,7 +657,6 @@ def linearize(fn, inputs):
vectors = tuple(vectors)
return linearize_inner(fn, vectors, output, args)
linearize_inner = _LinearizeInner()
if not isinstance(inputs, (Tensor, tuple, list)):
_raise_type_error()
if isinstance(inputs, Tensor):

View File

@ -19,6 +19,7 @@ import pytest
from mindspore import nn
from mindspore import context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops.functional import linearize, jvp
context.set_context(mode=context.GRAPH_MODE)
@ -181,3 +182,66 @@ def test_linearize_input_function_single_input_single_output_diverse_v_graph():
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_linearize_ms_function_single_input_single_output_diverse_v_graph():
"""
Features: Function linearize
Description: Test linearize with ms_function, single input, single output and diverse v in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v_0 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v_1 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
@ms_function
def linearize_with_ms_function(inputs, v_0, v_1):
output, jvp_fn = linearize(net, inputs)
grad_0 = jvp_fn(v_0)
grad_1 = jvp_fn(v_1)
return output, grad_0, grad_1
expect_primal, expect_grad_0 = jvp(net, x, v_0)
expect_primal, expect_grad_1 = jvp(net, x, v_1)
primal, grad_0, grad_1 = linearize_with_ms_function(x, v_0, v_1)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_linearize_construct_single_input_single_output_diverse_v_graph():
"""
Features: Function linearize
Description: Test linearize with construct, single input, single output and diverse v in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v_0 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v_1 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
class Net(nn.Cell):
def __init__(self, network):
super(Net, self).__init__()
self.net = network
def construct(self, inputs, v_0, v_1):
output, jvp_fn = linearize(net, inputs)
grad_0 = jvp_fn(v_0)
grad_1 = jvp_fn(v_1)
return output, grad_0, grad_1
test_net = Net(SingleInputSingleOutputNet())
expect_primal, expect_grad_0 = jvp(net, x, v_0)
expect_primal, expect_grad_1 = jvp(net, x, v_1)
primal, grad_0, grad_1 = test_net(x, v_0, v_1)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())