From c93ac8c098e8a5693ba7e9beed2bd93fcddd0ec9 Mon Sep 17 00:00:00 2001 From: "7347157+joylvliang@user.noreply.gitee.com" Date: Thu, 9 Sep 2021 20:32:54 +0800 Subject: [PATCH] add_graph_flag_info_for_phase --- mindspore/ccsrc/pipeline/jit/action.cc | 8 ++++---- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 3 +++ mindspore/ccsrc/pipeline/pynative/pynative_execute.h | 1 + mindspore/common/api.py | 5 +++++ 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 93c29128d30..40b5a5714d1 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -656,9 +656,10 @@ bool EliminateForwardCNode(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(graph_executor); auto phase = graph_executor->phase(); MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase; - // Export graph run in pynative mode no need to do this action. - if (phase.find("export") != std::string::npos) { - auto pynative_exec = pynative::PynativeExecutor::GetInstance(); + // Exporting graph in PyNative mode or only running forward process no need to do this action. + auto pynative_exec = pynative::PynativeExecutor::GetInstance(); + if (phase.find("export") == 0 || !pynative_exec->grad_flag()) { + MS_LOG(DEBUG) << "When exporting graph or only running forward process, no need to eliminate forward cnode."; auto grad_exec = pynative_exec->grad_executor(); grad_exec->set_eliminate_forward(true); return true; @@ -669,7 +670,6 @@ bool EliminateForwardCNode(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res); auto ms_func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(ms_func_graph); - auto pynative_exec = pynative::PynativeExecutor::GetInstance(); auto grad_exec = pynative_exec->grad_executor(); bool eliminate_forward = grad_exec->eliminate_forward(); grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty()); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 815cbcabb66..eb17446c763 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -3006,6 +3006,8 @@ ForwardExecutorPtr PynativeExecutor::forward_executor() const { return forward_executor_; } +bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); } + void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); } void PynativeExecutor::set_graph_phase(const std::string &graph_phase) { @@ -3140,6 +3142,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { .def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task") .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase") + .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag") .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), "Executor set grad flag.") .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path, diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 196338bb7ae..e5c1e8179e1 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -366,6 +366,7 @@ class PynativeExecutor : public std::enable_shared_from_this { GradExecutorPtr grad_executor() const; ForwardExecutorPtr forward_executor() const; + bool grad_flag() const; void set_grad_flag(bool flag); void set_graph_phase(const std::string &graph_phase); void set_py_exe_path(const py::object &py_exe_path); diff --git a/mindspore/common/api.py b/mindspore/common/api.py index f5cfcee2b44..0053e3fa822 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -157,6 +157,8 @@ class _MindsporeFunctionExecutor: dic = dict(zip(arg_names, args_list)) generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \ str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn)) + if _pynative_executor.grad_flag(): + generate_name = generate_name + ".grad" self.fn.__parse_method__ = method_name # Add key with obj @@ -398,6 +400,9 @@ class _PynativeExecutor: def set_graph_phase(self, phase): self._executor.set_graph_phase(phase) + def grad_flag(self): + return self._executor.grad_flag() + def set_grad_flag(self, flag): self._executor.set_grad_flag(flag)