forked from mindspore-Ecosystem/mindspore
modify export air
This commit is contained in:
parent
13d9ad0f2a
commit
58287c0d57
|
@ -177,19 +177,19 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors)
|
|||
bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::dict &init_params,
|
||||
const std::string &phase, const py::object &broadcast_params) {
|
||||
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
|
||||
DfGraphConvertor convertor(anf_graph);
|
||||
DfGraphConvertor converter(anf_graph);
|
||||
|
||||
size_t pos = phase.find('.');
|
||||
std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
|
||||
std::string phase_prefix = phase.substr(0, pos);
|
||||
if (phase_prefix == "export") {
|
||||
MS_LOG(INFO) << "Set DfGraphConvertor training : false";
|
||||
convertor.set_training(false);
|
||||
converter.set_training(false);
|
||||
}
|
||||
|
||||
TensorOrderMap init_tensors{};
|
||||
ConvertObjectToTensors(init_params, &init_tensors);
|
||||
(void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph();
|
||||
(void)converter.ConvertAllNode().InitParam(init_tensors).BuildGraph();
|
||||
|
||||
if (!broadcast_params.is_none()) {
|
||||
if (!py::isinstance<py::dict>(broadcast_params)) {
|
||||
|
@ -198,38 +198,38 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
|
|||
}
|
||||
py::dict broadcast = broadcast_params.cast<py::dict>();
|
||||
if (broadcast.empty()) {
|
||||
(void)convertor.GenerateBroadcastGraph(init_tensors);
|
||||
(void)converter.GenerateBroadcastGraph(init_tensors);
|
||||
} else {
|
||||
TensorOrderMap broadcast_tensors{};
|
||||
ConvertObjectToTensors(broadcast, &broadcast_tensors);
|
||||
(void)convertor.GenerateBroadcastGraph(broadcast_tensors);
|
||||
(void)converter.GenerateBroadcastGraph(broadcast_tensors);
|
||||
}
|
||||
MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty();
|
||||
}
|
||||
|
||||
(void)convertor.GenerateCheckpointGraph();
|
||||
if (convertor.ErrCode() != 0) {
|
||||
(void)converter.GenerateCheckpointGraph();
|
||||
if (converter.ErrCode() != 0) {
|
||||
DfGraphManager::GetInstance().ClearGraph();
|
||||
MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode();
|
||||
MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
convertor.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug
|
||||
convertor.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug
|
||||
convertor.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug
|
||||
converter.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug
|
||||
converter.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug
|
||||
converter.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug
|
||||
}
|
||||
std::string init_graph = "init_subgraph." + net_id;
|
||||
std::string checkpoint_name = "save." + net_id;
|
||||
if (phase.find("train") != std::string::npos) {
|
||||
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}});
|
||||
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}});
|
||||
} else {
|
||||
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph());
|
||||
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph());
|
||||
}
|
||||
(void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph());
|
||||
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph());
|
||||
(void)DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph());
|
||||
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph());
|
||||
|
||||
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph());
|
||||
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph());
|
||||
if (ret == Status::SUCCESS) {
|
||||
DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
|
||||
}
|
||||
|
@ -529,8 +529,10 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase) {
|
|||
return;
|
||||
}
|
||||
|
||||
(void)ge_graph->SaveToFile(file_name);
|
||||
MS_LOG(DEBUG) << "Export graph end.";
|
||||
if (ge_graph->SaveToFile(file_name) != 0) {
|
||||
MS_LOG(EXCEPTION) << "Export air model failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Export air model finish.";
|
||||
}
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -512,7 +512,7 @@ class _Executor:
|
|||
if "export" not in phase:
|
||||
init_phase = "init_subgraph" + "." + str(obj.create_time)
|
||||
_exec_init_graph(obj, init_phase)
|
||||
elif not enable_ge and "export" in phase:
|
||||
elif "export" in phase:
|
||||
self._build_data_graph(obj, phase)
|
||||
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
|
||||
_parameter_broadcast(obj, auto_parallel_mode)
|
||||
|
|
Loading…
Reference in New Issue