use find instead of equal to distinguish training graph

This commit is contained in:
chenhaozhe 2020-03-30 20:10:56 +08:00
parent 7210751e4a
commit b61ad0a5a7
1 changed files with 1 additions and 1 deletions

View File

@ -1071,7 +1071,7 @@ bool ExecutorPy::AddDFGraph(const py::dict& init_params, const std::string& phas
}
std::string init_graph = "init_subgraph." + net_id;
std::string checkpoint_name = "save." + net_id;
if (phase == "train") {
if (phase.find("train") != std::string::npos) {
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}});
} else {
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph());