diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 7bdf69bf883..31713a39bca 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -302,6 +302,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor auto inst = pynative::PynativeExecutor::GetInstance(); inst->SaveOpForwardValue(input_value.second, input_value.first); auto input_value_node = NewValueNode(input_value.first); + input_value_node->set_has_new_value(true); manager->Replace(paras[i], input_value_node); } } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index c9036dbf4ac..560c2ad44db 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -632,6 +632,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; return iter->second; } + if (!first_grad_step_) { + ++op_id_map_[id]; + } return nullptr; } @@ -979,7 +982,10 @@ void ClearPyNativeSession() { session = nullptr; } PynativeExecutor::~PynativeExecutor() { ClearRes(); } -PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } +PynativeExecutor::PynativeExecutor() { + grad_flag_ = false; + first_grad_step_ = false; +} void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); @@ -1000,6 +1006,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg cell_resource_map_[cell_id] = resource_; df_builder_ = std::make_shared(); MS_LOG(DEBUG) << "First new graph" << top_g_.get(); + first_grad_step_ = true; + top_graph_cells_.insert(cell_id); Pushp(); } else { Pushp(); @@ -1181,7 +1189,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); resource_->manager()->AddFuncGraph(curr_g_); // custom bprop debug + bool need_replace_param = false; if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + need_replace_param = true; size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); if (par_number > 0) { MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number @@ -1195,6 +1205,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje } } auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); + if (need_replace_param) { + auto params = newfg->parameters(); + auto manager = Manage({newfg}, false); + for (size_t i = 0; i < params.size(); i++) { + ValuePtr value = PyAttrValue(args[i]); + auto v_node = NewValueNode(value); + manager->Replace(params[i], v_node); + } + } graph_info_map_.erase(curr_g_); if (curr_g_ != top_g_) { Popp(); @@ -1355,6 +1374,9 @@ void PynativeExecutor::Clear(const std::string &flag) { ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); } ConfigManager::GetInstance().ResetIterNum(); + if (top_graph_cells_.find(flag) != top_graph_cells_.end()) { + op_forward_map_.clear(); + } return; } @@ -1363,6 +1385,7 @@ void PynativeExecutor::Clear(const std::string &flag) { top_g_ = nullptr; df_builder_ = nullptr; curr_g_ = nullptr; + first_grad_step_ = false; graph_info_map_.clear(); op_id_map_.clear(); obj_to_forward_id_.clear(); @@ -1374,7 +1397,6 @@ void PynativeExecutor::Clean() { MS_LOG(DEBUG) << "Clean all res"; Clear(); grad_flag_ = false; - op_forward_map_.clear(); ad::CleanRes(); pipeline::ReclaimOptimizer(); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index ee29353c21b..726e60033cd 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -145,6 +146,7 @@ class PynativeExecutor : public std::enable_shared_from_this { static ResourcePtr resource_; static int graph_id_; bool grad_flag_; + bool first_grad_step_; std::unordered_map graph_map_; std::unordered_map cell_graph_map_; std::unordered_map cell_resource_map_; @@ -158,6 +160,7 @@ class PynativeExecutor : public std::enable_shared_from_this { FuncGraphPtr df_builder_; FuncGraphPtr curr_g_; std::unordered_map prim_abs_list_; + std::set top_graph_cells_; }; using PynativeExecutorPtr = std::shared_ptr;