forked from mindspore-Ecosystem/mindspore
mindir:add @ms_function testcase
This commit is contained in:
parent
af6d16ec14
commit
6daabf9a86
|
@ -16,7 +16,7 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import context, ms_function
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.serialization import export, load
|
||||
|
||||
|
@ -52,6 +52,32 @@ def test_single_while():
|
|||
outputs_after_load = loaded_net(x, y)
|
||||
assert origin_out == outputs_after_load
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_function_while():
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
network = SingleWhileNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
origin_out = network(x, y)
|
||||
|
||||
file_name = "while_net"
|
||||
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||
mindir_name = file_name + ".mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
@ms_function
|
||||
def run_graph(x, y):
|
||||
outputs = loaded_net(x, y)
|
||||
return outputs
|
||||
outputs_after_load = run_graph(x, y)
|
||||
assert origin_out == outputs_after_load
|
||||
|
||||
|
||||
class SingleWhileInlineNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
|
@ -96,8 +122,6 @@ def test_single_while_inline_load():
|
|||
assert os.path.exists(mindir_name)
|
||||
load(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="inline is not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue