!23192 Add grad flag info for ms_function phase
Merge pull request !23192 from JoyLvliang/add_graph_flag_info_for_phase
This commit is contained in:
commit
dbfcbae190
|
@ -656,9 +656,10 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(graph_executor);
|
||||
auto phase = graph_executor->phase();
|
||||
MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase;
|
||||
// Export graph run in pynative mode no need to do this action.
|
||||
if (phase.find("export") != std::string::npos) {
|
||||
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
|
||||
// Exporting graph in PyNative mode or only running forward process no need to do this action.
|
||||
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
|
||||
if (phase.find("export") == 0 || !pynative_exec->grad_flag()) {
|
||||
MS_LOG(DEBUG) << "When exporting graph or only running forward process, no need to eliminate forward cnode.";
|
||||
auto grad_exec = pynative_exec->grad_executor();
|
||||
grad_exec->set_eliminate_forward(true);
|
||||
return true;
|
||||
|
@ -669,7 +670,6 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto ms_func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(ms_func_graph);
|
||||
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
|
||||
auto grad_exec = pynative_exec->grad_executor();
|
||||
bool eliminate_forward = grad_exec->eliminate_forward();
|
||||
grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());
|
||||
|
|
|
@ -3006,6 +3006,8 @@ ForwardExecutorPtr PynativeExecutor::forward_executor() const {
|
|||
return forward_executor_;
|
||||
}
|
||||
|
||||
bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }
|
||||
|
||||
void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); }
|
||||
|
||||
void PynativeExecutor::set_graph_phase(const std::string &graph_phase) {
|
||||
|
@ -3140,6 +3142,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|||
.def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
|
||||
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
||||
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
|
||||
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
|
||||
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
||||
"Executor set grad flag.")
|
||||
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
|
||||
|
|
|
@ -366,6 +366,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
GradExecutorPtr grad_executor() const;
|
||||
ForwardExecutorPtr forward_executor() const;
|
||||
|
||||
bool grad_flag() const;
|
||||
void set_grad_flag(bool flag);
|
||||
void set_graph_phase(const std::string &graph_phase);
|
||||
void set_py_exe_path(const py::object &py_exe_path);
|
||||
|
|
|
@ -157,6 +157,8 @@ class _MindsporeFunctionExecutor:
|
|||
dic = dict(zip(arg_names, args_list))
|
||||
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
|
||||
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
|
||||
if _pynative_executor.grad_flag():
|
||||
generate_name = generate_name + ".grad"
|
||||
self.fn.__parse_method__ = method_name
|
||||
|
||||
# Add key with obj
|
||||
|
@ -398,6 +400,9 @@ class _PynativeExecutor:
|
|||
def set_graph_phase(self, phase):
|
||||
self._executor.set_graph_phase(phase)
|
||||
|
||||
def grad_flag(self):
|
||||
return self._executor.grad_flag()
|
||||
|
||||
def set_grad_flag(self, flag):
|
||||
self._executor.set_grad_flag(flag)
|
||||
|
||||
|
|
Loading…
Reference in New Issue