forked from mindspore-Ecosystem/mindspore
convert pynative mode to graph mode when loading mindir
This commit is contained in:
parent
91434597c3
commit
d5efef0ae6
|
@ -1242,6 +1242,12 @@ std::vector<ActionItem> BackendPipeline() {
|
|||
return actions;
|
||||
}
|
||||
std::vector<ActionItem> MindIRPipeline() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "The graph generated form MindIR is not support to execute in the PynativeMode, please convert "
|
||||
"to the GraphMode.";
|
||||
}
|
||||
std::vector<ActionItem> actions;
|
||||
// Set funcGraph loaded from MindIR to resource.
|
||||
(void)actions.emplace_back(std::make_pair("load_mindir", SetMindIRGraphAction));
|
||||
|
|
|
@ -972,6 +972,8 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|||
|
||||
def _save_mindir(net, file_name, *inputs, **kwargs):
|
||||
"""Save MindIR format file."""
|
||||
if context._get_mode() == context.PYNATIVE_MODE:
|
||||
raise RuntimeError("MindIR export is not support in the Pynative mode, please convert to the Graph Mode.")
|
||||
model = mindir_model()
|
||||
|
||||
phase_name = "predict" if net._auto_parallel_mode else "export.mindir"
|
||||
|
|
|
@ -52,12 +52,13 @@ def test_single_while():
|
|||
outputs_after_load = loaded_net(x, y)
|
||||
assert origin_out == outputs_after_load
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_function_while():
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
|
@ -71,10 +72,13 @@ def test_ms_function_while():
|
|||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
@ms_function
|
||||
def run_graph(x, y):
|
||||
outputs = loaded_net(x, y)
|
||||
return outputs
|
||||
|
||||
outputs_after_load = run_graph(x, y)
|
||||
assert origin_out == outputs_after_load
|
||||
|
||||
|
@ -122,6 +126,7 @@ def test_single_while_inline_load():
|
|||
assert os.path.exists(mindir_name)
|
||||
load(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
|
@ -196,6 +196,7 @@ class ForwardBGCF(nn.Cell):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_export_bgcf():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
num_user, num_item = 7068, 3570
|
||||
network = BGCF([64, num_user, num_item], 64, "tanh",
|
||||
[0.0, 0.0, 0.0], num_user, num_item, 64)
|
||||
|
|
|
@ -94,18 +94,3 @@ def test_get_and_init_graph_cell_parameters_in_graph_mode():
|
|||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
get_and_init_graph_cell_parameters()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_and_init_graph_cell_parameters_in_pynative_mode():
|
||||
"""
|
||||
Description: load mind ir and update parameters in pynative mode.
|
||||
Expectation: generate a graph with updated parameters.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
get_and_init_graph_cell_parameters()
|
||||
|
|
|
@ -99,7 +99,7 @@ def test_init_graph_cell_parameters_with_wrong_value_shape():
|
|||
Description: load mind ir and update parameters with wrong tensor shape.
|
||||
Expectation: raise a ValueError indicating the update value shape error.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_2.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
|
Loading…
Reference in New Issue