forked from mindspore-Ecosystem/mindspore
!607 optimize flow of export onnx model
Merge pull request !607 from fary86/optimize_flow_of_exporting_onnx_model
This commit is contained in:
commit
4a001ece95
|
@ -294,6 +294,30 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
|
||||||
MS_LOG(INFO) << "End save compiled func graph!";
|
MS_LOG(INFO) << "End save compiled func graph!";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) {
|
||||||
|
#ifdef ENABLE_DUMP_IR
|
||||||
|
// save the graph to file in protobuf format
|
||||||
|
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
std::string name_prefix = phase_s.substr(0, phase_s.find("."));
|
||||||
|
std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb";
|
||||||
|
std::string filename = GetFilePathName(pb_filename);
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Begin saving graph to file <<'" << filename << "' in protobuf formart.";
|
||||||
|
ChangeFileMode(filename, S_IRWXU);
|
||||||
|
std::ofstream ofs(filename);
|
||||||
|
if (!ofs.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ofs << GetFuncGraphProtoString(func_graph);
|
||||||
|
ofs.close();
|
||||||
|
// set file mode to read only by user
|
||||||
|
ChangeFileMode(filename, S_IRUSR);
|
||||||
|
MS_LOG(INFO) << "End saving graph to file in protobuf format";
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const {
|
bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const {
|
||||||
std::string phase_prefix = GetPhasePrefix(phase_s);
|
std::string phase_prefix = GetPhasePrefix(phase_s);
|
||||||
|
|
||||||
|
@ -365,6 +389,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
|
||||||
info_[phase_s] = executor_info;
|
info_[phase_s] = executor_info;
|
||||||
pip->Run();
|
pip->Run();
|
||||||
|
|
||||||
|
// save compile graph to file in protobuf format
|
||||||
|
SaveCompiledGraphToPb(phase_s);
|
||||||
// save the run graph func to MsPipeLine
|
// save the run graph func to MsPipeLine
|
||||||
SaveCompiledGraph(phase_s);
|
SaveCompiledGraph(phase_s);
|
||||||
|
|
||||||
|
@ -557,20 +583,6 @@ void Pipeline::Run() {
|
||||||
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
|
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
|
||||||
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
|
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
|
||||||
draw::DrawUserFuncGraph(user_graph_file, user_graph);
|
draw::DrawUserFuncGraph(user_graph_file, user_graph);
|
||||||
|
|
||||||
#ifdef ENABLE_DUMP_IR
|
|
||||||
std::string filename = GetFilePathName("ms_output.pb");
|
|
||||||
ChangeFileMode(filename, S_IRWXU);
|
|
||||||
std::ofstream ofs(filename);
|
|
||||||
if (!ofs.is_open()) {
|
|
||||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ofs << GetFuncGraphProtoString(user_graph);
|
|
||||||
ofs.close();
|
|
||||||
// set file mode to read only by user
|
|
||||||
ChangeFileMode(filename, S_IRUSR);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "End";
|
MS_LOG(INFO) << "End";
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,6 +70,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
||||||
~ExecutorPy();
|
~ExecutorPy();
|
||||||
|
|
||||||
void SaveCompiledGraph(const std::string &phase_s);
|
void SaveCompiledGraph(const std::string &phase_s);
|
||||||
|
void SaveCompiledGraphToPb(const std::string &phase_s);
|
||||||
bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
||||||
bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,7 @@ void Profile::Print(void) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
PrintProfile(oss, *ctx_ptr_->time_info_);
|
PrintProfile(oss, *ctx_ptr_->time_info_);
|
||||||
std::string text = oss.str();
|
std::string text = oss.str();
|
||||||
// the length of text is too long to use MS_LOGINFO, use printf to print it
|
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
|
||||||
(void)printf("%s", text.c_str());
|
(void)printf("%s", text.c_str());
|
||||||
(void)fflush(stdout);
|
(void)fflush(stdout);
|
||||||
}
|
}
|
||||||
|
@ -358,7 +358,7 @@ void MsProfile::Print() {
|
||||||
PrintTimeStat(oss, groups[i], prefix);
|
PrintTimeStat(oss, groups[i], prefix);
|
||||||
}
|
}
|
||||||
std::string text = oss.str();
|
std::string text = oss.str();
|
||||||
// the length of text is too long to use MS_LOGINFO, use printf to print it
|
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
|
||||||
(void)printf("\nTime group info:\n%s", text.c_str());
|
(void)printf("\nTime group info:\n%s", text.c_str());
|
||||||
(void)fflush(stdout);
|
(void)fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
|
@ -328,7 +328,7 @@ class _Executor:
|
||||||
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
||||||
format(type(params)))
|
format(type(params)))
|
||||||
|
|
||||||
def compile(self, obj, *args, phase='predict', params=None):
|
def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
|
||||||
"""
|
"""
|
||||||
Compiles graph.
|
Compiles graph.
|
||||||
|
|
||||||
|
@ -337,6 +337,7 @@ class _Executor:
|
||||||
args (tuple): Function or cell input arguments.
|
args (tuple): Function or cell input arguments.
|
||||||
phase (str): The name of compile phase. Default: 'predict'.
|
phase (str): The name of compile phase. Default: 'predict'.
|
||||||
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
|
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
|
||||||
|
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Str, the full phase of the cell.
|
Str, the full phase of the cell.
|
||||||
|
@ -368,7 +369,8 @@ class _Executor:
|
||||||
|
|
||||||
if graph is None:
|
if graph is None:
|
||||||
logger.error("%r graph compile failed.", phase)
|
logger.error("%r graph compile failed.", phase)
|
||||||
|
if not do_convert:
|
||||||
|
return phase, True
|
||||||
if not enable_debug_runtime or enable_ge:
|
if not enable_debug_runtime or enable_ge:
|
||||||
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
|
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
|
||||||
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
||||||
|
|
|
@ -450,7 +450,7 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
||||||
_executor.export(net, file_name, file_format)
|
_executor.export(net, file_name, file_format)
|
||||||
elif file_format == 'ONNX': # file_format is 'ONNX'
|
elif file_format == 'ONNX': # file_format is 'ONNX'
|
||||||
phase_name = 'export_onnx'
|
phase_name = 'export_onnx'
|
||||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||||
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
||||||
with open(file_name, 'wb') as f:
|
with open(file_name, 'wb') as f:
|
||||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
|
Loading…
Reference in New Issue