forked from mindspore-Ecosystem/mindspore
!607 optimize flow of export onnx model
Merge pull request !607 from fary86/optimize_flow_of_exporting_onnx_model
This commit is contained in:
commit
4a001ece95
|
@ -294,6 +294,30 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
|
|||
MS_LOG(INFO) << "End save compiled func graph!";
|
||||
}
|
||||
|
||||
void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) {
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
// save the graph to file in protobuf format
|
||||
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::string name_prefix = phase_s.substr(0, phase_s.find("."));
|
||||
std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb";
|
||||
std::string filename = GetFilePathName(pb_filename);
|
||||
|
||||
MS_LOG(INFO) << "Begin saving graph to file <<'" << filename << "' in protobuf formart.";
|
||||
ChangeFileMode(filename, S_IRWXU);
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||
return;
|
||||
}
|
||||
ofs << GetFuncGraphProtoString(func_graph);
|
||||
ofs.close();
|
||||
// set file mode to read only by user
|
||||
ChangeFileMode(filename, S_IRUSR);
|
||||
MS_LOG(INFO) << "End saving graph to file in protobuf format";
|
||||
#endif
|
||||
}
|
||||
|
||||
bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const {
|
||||
std::string phase_prefix = GetPhasePrefix(phase_s);
|
||||
|
||||
|
@ -365,6 +389,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
|
|||
info_[phase_s] = executor_info;
|
||||
pip->Run();
|
||||
|
||||
// save compile graph to file in protobuf format
|
||||
SaveCompiledGraphToPb(phase_s);
|
||||
// save the run graph func to MsPipeLine
|
||||
SaveCompiledGraph(phase_s);
|
||||
|
||||
|
@ -557,20 +583,6 @@ void Pipeline::Run() {
|
|||
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
|
||||
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
|
||||
draw::DrawUserFuncGraph(user_graph_file, user_graph);
|
||||
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
std::string filename = GetFilePathName("ms_output.pb");
|
||||
ChangeFileMode(filename, S_IRWXU);
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||
return;
|
||||
}
|
||||
ofs << GetFuncGraphProtoString(user_graph);
|
||||
ofs.close();
|
||||
// set file mode to read only by user
|
||||
ChangeFileMode(filename, S_IRUSR);
|
||||
#endif
|
||||
}
|
||||
MS_LOG(INFO) << "End";
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
|||
~ExecutorPy();
|
||||
|
||||
void SaveCompiledGraph(const std::string &phase_s);
|
||||
void SaveCompiledGraphToPb(const std::string &phase_s);
|
||||
bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
||||
bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ void Profile::Print(void) {
|
|||
std::ostringstream oss;
|
||||
PrintProfile(oss, *ctx_ptr_->time_info_);
|
||||
std::string text = oss.str();
|
||||
// the length of text is too long to use MS_LOGINFO, use printf to print it
|
||||
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
|
||||
(void)printf("%s", text.c_str());
|
||||
(void)fflush(stdout);
|
||||
}
|
||||
|
@ -358,7 +358,7 @@ void MsProfile::Print() {
|
|||
PrintTimeStat(oss, groups[i], prefix);
|
||||
}
|
||||
std::string text = oss.str();
|
||||
// the length of text is too long to use MS_LOGINFO, use printf to print it
|
||||
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
|
||||
(void)printf("\nTime group info:\n%s", text.c_str());
|
||||
(void)fflush(stdout);
|
||||
}
|
||||
|
|
|
@ -328,7 +328,7 @@ class _Executor:
|
|||
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
||||
format(type(params)))
|
||||
|
||||
def compile(self, obj, *args, phase='predict', params=None):
|
||||
def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
|
||||
"""
|
||||
Compiles graph.
|
||||
|
||||
|
@ -337,6 +337,7 @@ class _Executor:
|
|||
args (tuple): Function or cell input arguments.
|
||||
phase (str): The name of compile phase. Default: 'predict'.
|
||||
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
|
||||
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
|
||||
|
||||
Return:
|
||||
Str, the full phase of the cell.
|
||||
|
@ -368,7 +369,8 @@ class _Executor:
|
|||
|
||||
if graph is None:
|
||||
logger.error("%r graph compile failed.", phase)
|
||||
|
||||
if not do_convert:
|
||||
return phase, True
|
||||
if not enable_debug_runtime or enable_ge:
|
||||
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
|
||||
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
||||
|
|
|
@ -450,7 +450,7 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|||
_executor.export(net, file_name, file_format)
|
||||
elif file_format == 'ONNX': # file_format is 'ONNX'
|
||||
phase_name = 'export_onnx'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
|
Loading…
Reference in New Issue