forked from mindspore-Ecosystem/mindspore
!40032 Add linearize ms_function test case
Merge pull request !40032 from ZhengXuegui/test_linearize
This commit is contained in:
commit
ba031c7277
|
@ -646,6 +646,7 @@ def linearize(fn, inputs):
|
|||
[[12. 15.]
|
||||
[16. 19.]]
|
||||
"""
|
||||
linearize_inner = _LinearizeInner()
|
||||
|
||||
@ms_function(hash_args=fn)
|
||||
def _wrap_container(*arg):
|
||||
|
@ -656,7 +657,6 @@ def linearize(fn, inputs):
|
|||
vectors = tuple(vectors)
|
||||
return linearize_inner(fn, vectors, output, args)
|
||||
|
||||
linearize_inner = _LinearizeInner()
|
||||
if not isinstance(inputs, (Tensor, tuple, list)):
|
||||
_raise_type_error()
|
||||
if isinstance(inputs, Tensor):
|
||||
|
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import linearize, jvp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -181,3 +182,66 @@ def test_linearize_input_function_single_input_single_output_diverse_v_graph():
|
|||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_linearize_ms_function_single_input_single_output_diverse_v_graph():
|
||||
"""
|
||||
Features: Function linearize
|
||||
Description: Test linearize with ms_function, single input, single output and diverse v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v_0 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v_1 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
|
||||
@ms_function
|
||||
def linearize_with_ms_function(inputs, v_0, v_1):
|
||||
output, jvp_fn = linearize(net, inputs)
|
||||
grad_0 = jvp_fn(v_0)
|
||||
grad_1 = jvp_fn(v_1)
|
||||
return output, grad_0, grad_1
|
||||
|
||||
expect_primal, expect_grad_0 = jvp(net, x, v_0)
|
||||
expect_primal, expect_grad_1 = jvp(net, x, v_1)
|
||||
primal, grad_0, grad_1 = linearize_with_ms_function(x, v_0, v_1)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_linearize_construct_single_input_single_output_diverse_v_graph():
|
||||
"""
|
||||
Features: Function linearize
|
||||
Description: Test linearize with construct, single input, single output and diverse v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v_0 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v_1 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Net, self).__init__()
|
||||
self.net = network
|
||||
|
||||
def construct(self, inputs, v_0, v_1):
|
||||
output, jvp_fn = linearize(net, inputs)
|
||||
grad_0 = jvp_fn(v_0)
|
||||
grad_1 = jvp_fn(v_1)
|
||||
return output, grad_0, grad_1
|
||||
|
||||
test_net = Net(SingleInputSingleOutputNet())
|
||||
expect_primal, expect_grad_0 = jvp(net, x, v_0)
|
||||
expect_primal, expect_grad_1 = jvp(net, x, v_1)
|
||||
primal, grad_0, grad_1 = test_net(x, v_0, v_1)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad_0.asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad_1.asnumpy(), expect_grad_1.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue