From 6daabf9a868b16ca254b095ba86d97b43724bf0d Mon Sep 17 00:00:00 2001 From: lanzhineng Date: Tue, 24 Aug 2021 16:39:44 +0800 Subject: [PATCH] mindir:add @ms_function testcase --- tests/st/control/test_while_mindir.py | 30 ++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/st/control/test_while_mindir.py b/tests/st/control/test_while_mindir.py index b68a3a14477..efaca20c8a5 100644 --- a/tests/st/control/test_while_mindir.py +++ b/tests/st/control/test_while_mindir.py @@ -16,7 +16,7 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import context +from mindspore import context, ms_function from mindspore.common.tensor import Tensor from mindspore.train.serialization import export, load @@ -52,6 +52,32 @@ def test_single_while(): outputs_after_load = loaded_net(x, y) assert origin_out == outputs_after_load +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_ms_function_while(): + context.set_context(mode=context.PYNATIVE_MODE) + network = SingleWhileNet() + + x = Tensor(np.array([1]).astype(np.float32)) + y = Tensor(np.array([2]).astype(np.float32)) + origin_out = network(x, y) + + file_name = "while_net" + export(network, x, y, file_name=file_name, file_format='MINDIR') + mindir_name = file_name + ".mindir" + assert os.path.exists(mindir_name) + + graph = load(mindir_name) + loaded_net = nn.GraphCell(graph) + @ms_function + def run_graph(x, y): + outputs = loaded_net(x, y) + return outputs + outputs_after_load = run_graph(x, y) + assert origin_out == outputs_after_load + class SingleWhileInlineNet(nn.Cell): def construct(self, x, y): @@ -96,8 +122,6 @@ def test_single_while_inline_load(): assert os.path.exists(mindir_name) load(mindir_name) - -@pytest.mark.skip(reason="inline is not supported yet") @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training