From f7664312a1422e75a9098d31cea8680215c008c3 Mon Sep 17 00:00:00 2001 From: fary86 Date: Sat, 18 Jul 2020 15:48:54 +0800 Subject: [PATCH] Fix dump geir fail --- mindspore/common/api.py | 8 +++----- mindspore/train/serialization.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 050baf9f79e..3ef21b96263 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -515,18 +515,16 @@ class _Executor: return None return self._executor.get_func_graph_proto(exec_id, ir_type) - def export(self, net, file_name, file_format='GEIR'): + def export(self, file_name, graph_id): """ Export graph. Args: - net (Cell): MindSpore network file_name (str): File name of model to export - file_format (str): MindSpore currently support 'GEIR' and 'ONNX' format for exported model + graph_id (str): id of graph to be exported """ from .._c_expression import export_graph - phase = 'export' + '.' + self.phase_prefix + '.' + str(net.create_time) - export_graph(file_name, file_format, phase) + export_graph(file_name, 'GEIR', graph_id) def fetch_info_for_quant_export(self, exec_id): """Get graph proto from pipeline.""" diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 7bd5fdca490..de35981d314 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -455,8 +455,8 @@ def export(net, *inputs, file_name, file_format='GEIR'): net.init_parameters_data() if file_format == 'GEIR': phase_name = 'export.geir' - _executor.compile(net, *inputs, phase=phase_name) - _executor.export(net, file_name, file_format) + graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) + _executor.export(file_name, graph_id) elif file_format == 'ONNX': # file_format is 'ONNX' # NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline, # do not change it to other values.