modify export air

This commit is contained in:
changzherui 2021-04-07 21:02:27 +08:00
parent 13d9ad0f2a
commit 58287c0d57
2 changed files with 21 additions and 19 deletions

View File

@ -177,19 +177,19 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors)
bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::dict &init_params,
const std::string &phase, const py::object &broadcast_params) {
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
DfGraphConvertor convertor(anf_graph);
DfGraphConvertor converter(anf_graph);
size_t pos = phase.find('.');
std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
std::string phase_prefix = phase.substr(0, pos);
if (phase_prefix == "export") {
MS_LOG(INFO) << "Set DfGraphConvertor training : false";
convertor.set_training(false);
converter.set_training(false);
}
TensorOrderMap init_tensors{};
ConvertObjectToTensors(init_params, &init_tensors);
(void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph();
(void)converter.ConvertAllNode().InitParam(init_tensors).BuildGraph();
if (!broadcast_params.is_none()) {
if (!py::isinstance<py::dict>(broadcast_params)) {
@ -198,38 +198,38 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
}
py::dict broadcast = broadcast_params.cast<py::dict>();
if (broadcast.empty()) {
(void)convertor.GenerateBroadcastGraph(init_tensors);
(void)converter.GenerateBroadcastGraph(init_tensors);
} else {
TensorOrderMap broadcast_tensors{};
ConvertObjectToTensors(broadcast, &broadcast_tensors);
(void)convertor.GenerateBroadcastGraph(broadcast_tensors);
(void)converter.GenerateBroadcastGraph(broadcast_tensors);
}
MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty();
}
(void)convertor.GenerateCheckpointGraph();
if (convertor.ErrCode() != 0) {
(void)converter.GenerateCheckpointGraph();
if (converter.ErrCode() != 0) {
DfGraphManager::GetInstance().ClearGraph();
MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode();
MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode();
return false;
}
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
convertor.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug
convertor.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug
convertor.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug
converter.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug
converter.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug
converter.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug
}
std::string init_graph = "init_subgraph." + net_id;
std::string checkpoint_name = "save." + net_id;
if (phase.find("train") != std::string::npos) {
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}});
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}});
} else {
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph());
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph());
}
(void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph());
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph());
(void)DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph());
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph());
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph());
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph());
if (ret == Status::SUCCESS) {
DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
}
@ -529,8 +529,10 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase) {
return;
}
(void)ge_graph->SaveToFile(file_name);
MS_LOG(DEBUG) << "Export graph end.";
if (ge_graph->SaveToFile(file_name) != 0) {
MS_LOG(EXCEPTION) << "Export air model failed.";
}
MS_LOG(INFO) << "Export air model finish.";
}
} // namespace pipeline
} // namespace mindspore

View File

@ -512,7 +512,7 @@ class _Executor:
if "export" not in phase:
init_phase = "init_subgraph" + "." + str(obj.create_time)
_exec_init_graph(obj, init_phase)
elif not enable_ge and "export" in phase:
elif "export" in phase:
self._build_data_graph(obj, phase)
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
_parameter_broadcast(obj, auto_parallel_mode)