mindir:add @ms_function testcase

This commit is contained in:
lanzhineng 2021-08-24 16:39:44 +08:00
parent af6d16ec14
commit 6daabf9a86
1 changed files with 27 additions and 3 deletions

View File

@ -16,7 +16,7 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore import context, ms_function
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
@ -52,6 +52,32 @@ def test_single_while():
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_ms_function_while():
context.set_context(mode=context.PYNATIVE_MODE)
network = SingleWhileNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "while_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
@ms_function
def run_graph(x, y):
outputs = loaded_net(x, y)
return outputs
outputs_after_load = run_graph(x, y)
assert origin_out == outputs_after_load
class SingleWhileInlineNet(nn.Cell):
def construct(self, x, y):
@ -96,8 +122,6 @@ def test_single_while_inline_load():
assert os.path.exists(mindir_name)
load(mindir_name)
@pytest.mark.skip(reason="inline is not supported yet")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training