diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index 0c2edfc9c3f..7b3cf06b5eb 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -1104,7 +1104,7 @@ bool ExecutorPy::AddDFGraph(const py::dict& init_params, const std::string& phas } std::string init_graph = "init_subgraph." + net_id; std::string checkpoint_name = "save." + net_id; - if (phase == "train") { + if (phase.find("train") != std::string::npos) { (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); } else { (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph());