convert pynative mode to graph mode when loading mindir

This commit is contained in:
lianliguang 2021-12-24 14:23:49 +08:00
parent 91434597c3
commit d5efef0ae6
6 changed files with 16 additions and 17 deletions

View File

@ -1242,6 +1242,12 @@ std::vector<ActionItem> BackendPipeline() {
return actions;
}
std::vector<ActionItem> MindIRPipeline() {
auto context_ptr = MsContext::GetInstance();
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(EXCEPTION)
<< "The graph generated form MindIR is not support to execute in the PynativeMode, please convert "
"to the GraphMode.";
}
std::vector<ActionItem> actions;
// Set funcGraph loaded from MindIR to resource.
(void)actions.emplace_back(std::make_pair("load_mindir", SetMindIRGraphAction));

View File

@ -972,6 +972,8 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
def _save_mindir(net, file_name, *inputs, **kwargs):
"""Save MindIR format file."""
if context._get_mode() == context.PYNATIVE_MODE:
raise RuntimeError("MindIR export is not support in the Pynative mode, please convert to the Graph Mode.")
model = mindir_model()
phase_name = "predict" if net._auto_parallel_mode else "export.mindir"

View File

@ -52,12 +52,13 @@ def test_single_while():
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_ms_function_while():
context.set_context(mode=context.PYNATIVE_MODE)
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileNet()
x = Tensor(np.array([1]).astype(np.float32))
@ -71,10 +72,13 @@ def test_ms_function_while():
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def run_graph(x, y):
outputs = loaded_net(x, y)
return outputs
outputs_after_load = run_graph(x, y)
assert origin_out == outputs_after_load
@ -122,6 +126,7 @@ def test_single_while_inline_load():
assert os.path.exists(mindir_name)
load(mindir_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training

View File

@ -196,6 +196,7 @@ class ForwardBGCF(nn.Cell):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_bgcf():
context.set_context(mode=context.GRAPH_MODE)
num_user, num_item = 7068, 3570
network = BGCF([64, num_user, num_item], 64, "tanh",
[0.0, 0.0, 0.0], num_user, num_item, 64)

View File

@ -94,18 +94,3 @@ def test_get_and_init_graph_cell_parameters_in_graph_mode():
"""
context.set_context(mode=context.GRAPH_MODE)
get_and_init_graph_cell_parameters()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters_in_pynative_mode():
"""
Description: load mind ir and update parameters in pynative mode.
Expectation: generate a graph with updated parameters.
"""
context.set_context(mode=context.PYNATIVE_MODE)
get_and_init_graph_cell_parameters()

View File

@ -99,7 +99,7 @@ def test_init_graph_cell_parameters_with_wrong_value_shape():
Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the update value shape error.
"""
context.set_context(mode=context.PYNATIVE_MODE)
context.set_context(mode=context.GRAPH_MODE)
net = Net()
mindir_name = "net_2.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')