!23192 Add grad flag info for ms_function phase

Merge pull request !23192 from JoyLvliang/add_graph_flag_info_for_phase
This commit is contained in:
i-robot 2021-09-12 01:54:55 +00:00 committed by Gitee
commit dbfcbae190
4 changed files with 13 additions and 4 deletions

View File

@ -656,9 +656,10 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(graph_executor);
auto phase = graph_executor->phase();
MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase;
// Export graph run in pynative mode no need to do this action.
if (phase.find("export") != std::string::npos) {
// Exporting graph in PyNative mode or only running forward process no need to do this action.
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
if (phase.find("export") == 0 || !pynative_exec->grad_flag()) {
MS_LOG(DEBUG) << "When exporting graph or only running forward process, no need to eliminate forward cnode.";
auto grad_exec = pynative_exec->grad_executor();
grad_exec->set_eliminate_forward(true);
return true;
@ -669,7 +670,6 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto ms_func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(ms_func_graph);
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
bool eliminate_forward = grad_exec->eliminate_forward();
grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());

View File

@ -3006,6 +3006,8 @@ ForwardExecutorPtr PynativeExecutor::forward_executor() const {
return forward_executor_;
}
bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }
void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); }
void PynativeExecutor::set_graph_phase(const std::string &graph_phase) {
@ -3140,6 +3142,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.")
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,

View File

@ -366,6 +366,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
GradExecutorPtr grad_executor() const;
ForwardExecutorPtr forward_executor() const;
bool grad_flag() const;
void set_grad_flag(bool flag);
void set_graph_phase(const std::string &graph_phase);
void set_py_exe_path(const py::object &py_exe_path);

View File

@ -157,6 +157,8 @@ class _MindsporeFunctionExecutor:
dic = dict(zip(arg_names, args_list))
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
if _pynative_executor.grad_flag():
generate_name = generate_name + ".grad"
self.fn.__parse_method__ = method_name
# Add key with obj
@ -398,6 +400,9 @@ class _PynativeExecutor:
def set_graph_phase(self, phase):
self._executor.set_graph_phase(phase)
def grad_flag(self):
return self._executor.grad_flag()
def set_grad_flag(self, flag):
self._executor.set_grad_flag(flag)